From 8f1bf236bf614ddd803915afc231867e7b957e21 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Wed, 28 Apr 2021 16:05:16 -0700 Subject: [PATCH 01/20] finish soft actor critic --- pl_bolts/models/rl/common/agents.py | 45 +++ pl_bolts/models/rl/common/distributions.py | 69 ++++ pl_bolts/models/rl/common/networks.py | 69 ++++ pl_bolts/models/rl/sac_model.py | 440 +++++++++++++++++++++ 4 files changed, 623 insertions(+) create mode 100644 pl_bolts/models/rl/common/distributions.py create mode 100644 pl_bolts/models/rl/sac_model.py diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index 6a8a2895c5..92478ba1b4 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -138,3 +138,48 @@ def __call__(self, states: torch.Tensor, device: str) -> List[int]: actions = [np.random.choice(len(prob), p=prob) for prob in prob_np] return actions + +class SoftActorCriticAgent(Agent): + """Actor-Critic based agent that returns a continuous action based on the policy""" + def __call__(self, states: torch.Tensor, device: str) -> List[float]: + """ + Takes in the current state and returns the action based on the agents policy + + Args: + states: current state of the environment + device: the device used for the current batch + + Returns: + action defined by policy + """ + if not isinstance(states, list): + states = [states] + + if not isinstance(states, torch.Tensor): + states = torch.tensor(states, device=device) + + dist = self.net(states) + actions = [a for a in dist.sample().cpu().numpy()] + + return actions + + def get_action(self, states: torch.Tensor, device: str) -> List[float]: + """ + Get the action greedily (without sampling) + + Args: + states: current state of the environment + device: the device used for the current batch + + Returns: + action defined by policy + """ + if not isinstance(states, list): + states = [states] + + if not isinstance(states, torch.Tensor): + states = torch.tensor(states, device=device) + + actions = [self.net.get_action(states).cpu().numpy()] + + return actions diff --git a/pl_bolts/models/rl/common/distributions.py b/pl_bolts/models/rl/common/distributions.py new file mode 100644 index 0000000000..2230b84c5d --- /dev/null +++ b/pl_bolts/models/rl/common/distributions.py @@ -0,0 +1,69 @@ +""" +Distributions used in some continuous RL algorithms +""" +import torch + + +class TanhMultivariateNormal(torch.distributions.MultivariateNormal): + """ + The distribution of X is an affine of tanh applied on a normal distribution + X = action_scale * tanh(Z) + action_bias + Z ~ Normal(mean, variance) + """ + def __init__(self, action_bias, action_scale, **kwargs): + super().__init__(**kwargs) + + self.action_bias = action_bias + self.action_scale = action_scale + + def rsample_with_z(self, sample_shape=torch.Size()): + """ + Samples X using reparametrization trick with the intermediate variable Z + + Returns: + Sampled X and Z + """ + z = super().rsample() + return self.action_scale * torch.tanh(z) + self.action_bias, z + + def log_prob_with_z(self, value, z): + """ + Computes the log probability of a sampled X + + Refer to the original paper of SAC for more details in equation (20), (21) + + Args: + value: the value of X + z: the value of Z + Returns: + Log probability of the sample + """ + value = (value - self.action_bias) / self.action_scale + z_logprob = super().log_prob(z) + correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1) + return z_logprob - correction + + def 
rsample_and_log_prob(self, sample_shape=torch.Size()): + """ + Samples X and computes the log probability of the sample + + Returns: + Sampled X and log probability + """ + z = super().rsample() + z_logprob = super().log_prob(z) + value = torch.tanh(z) + correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1) + return self.action_scale * value + self.action_bias, z_logprob - correction + + """ + Some override methods + """ + def rsample(self, sample_shape=torch.Size()): + fz, z = self.rsample_with_z(sample_shape) + return fz + + def log_prob(self, value): + value = (value - self.action_bias) / self.action_scale + z = torch.log(1 + value) / 2 - torch.log(1-value) / 2 + return self.log_prob_with_z(value, z) diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index 2a47433797..74d2439dc8 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -9,6 +9,8 @@ from torch import nn, Tensor from torch.nn import functional as F +from pl_bolts.models.rl.common.distributions import TanhMultivariateNormal + class CNN(nn.Module): """ @@ -92,6 +94,73 @@ def forward(self, input_x): return self.net(input_x.float()) +class ContinuousMLP(nn.Module): + """ + MLP network that outputs continuous value via Gaussian distribution + """ + def __init__( + self, + input_shape: Tuple[int], + n_actions: int, + hidden_size: int = 128, + action_bias: int = 0, + action_scale: int = 1 + ): + """ + Args: + input_shape: observation shape of the environment + n_actions: dimension of actions in the environment + hidden_size: size of hidden layers + action_bias: the center of the action space + action_scale: the scale of the action space + """ + super(ContinuousMLP, self).__init__() + self.action_bias = action_bias + self.action_scale = action_scale + + self.shared_net = nn.Sequential( + nn.Linear(input_shape[0], hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.ReLU() + ) + self.mean_layer = nn.Linear(hidden_size, n_actions) + self.logstd_layer = nn.Linear(hidden_size, n_actions) + + def forward(self, x: torch.FloatTensor) -> TanhMultivariateNormal: + """ + Forward pass through network. 
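For the squashed distribution defined above, the log-probability of a sample is the underlying Normal log-probability minus the tanh-Jacobian correction log(action_scale * (1 - tanh(z)^2)) summed over action dimensions, which is exactly what `log_prob_with_z` and `rsample_and_log_prob` compute (equations (20)-(21) of the SAC paper, as the docstring notes). A minimal usage sketch, not part of the patch: it assumes the `TanhMultivariateNormal` class exactly as added in `pl_bolts/models/rl/common/distributions.py`, and the ±2 action bound is chosen only to mimic Pendulum-v0::

    import torch
    from pl_bolts.models.rl.common.distributions import TanhMultivariateNormal

    # Batch of 4 samples from a 1-D action space squashed into [-2, 2]
    loc = torch.zeros(4, 1)
    scale_tril = torch.diag_embed(0.5 * torch.ones(4, 1))
    dist = TanhMultivariateNormal(action_bias=0.0, action_scale=2.0, loc=loc, scale_tril=scale_tril)

    actions, log_probs = dist.rsample_and_log_prob()
    assert actions.shape == (4, 1) and log_probs.shape == (4,)
    assert actions.abs().max() < 2.0  # tanh squashing keeps samples strictly inside the bound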
Calculates the action distribution + + Args: + x: input to network + Returns: + action distribution + """ + x = self.shared_net(x.float()) + batch_mean = self.mean_layer(x) + logstd = torch.clamp(self.logstd_layer(x), -20, 2) + batch_scale_tril = torch.diag_embed(torch.exp(logstd)) + return TanhMultivariateNormal( + action_bias=self.action_bias, + action_scale=self.action_scale, + loc=batch_mean, + scale_tril=batch_scale_tril + ) + + def get_action(self, x: torch.FloatTensor) -> torch.Tensor: + """ + Get the action greedily (without sampling) + + Args: + x: input to network + Returns: + mean action + """ + x = self.shared_net(x.float()) + batch_mean = self.mean_layer(x) + return self.action_scale * torch.tanh(batch_mean) + self.action_bias + + class DuelingMLP(nn.Module): """ MLP network with duel heads for val and advantage diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py new file mode 100644 index 0000000000..ef5689d238 --- /dev/null +++ b/pl_bolts/models/rl/sac_model.py @@ -0,0 +1,440 @@ +""" +Soft Actor Critic +""" +import argparse +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from torch import optim as optim +from torch.optim.optimizer import Optimizer +from torch.nn import functional as F +from torch.utils.data import DataLoader + +from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset +from pl_bolts.losses.rl import dqn_loss +from pl_bolts.models.rl.common.agents import SoftActorCriticAgent +from pl_bolts.models.rl.common.gym_wrappers import make_environment +from pl_bolts.models.rl.common.memory import MultiStepBuffer +from pl_bolts.models.rl.common.networks import ContinuousMLP, MLP +from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _GYM_AVAILABLE: + import gym +else: # pragma: no cover + warn_missing_pkg('gym') + Env = object + + +class SAC(pl.LightningModule): + def __init__( + self, + env: str, + eps_start: float = 1.0, + eps_end: float = 0.02, + eps_last_frame: int = 150000, + sync_rate: int = 1, + gamma: float = 0.99, + policy_learning_rate: float = 3e-4, + q_learning_rate: float = 3e-4, + target_alpha: float = 5e-3, + batch_size: int = 256, + replay_size: int = 100000, + warm_start_size: int = 1000000, + avg_reward_len: int = 100, + min_episode_reward: int = -21, + seed: int = 123, + batches_per_epoch: int = 1000, + n_steps: int = 1, + **kwargs, + ): + super().__init__() + + # Environment + self.env = gym.make(env) + self.test_env = gym.make(env) + + self.obs_shape = self.env.observation_space.shape + self.n_actions = self.env.action_space.shape[0] + + # Model Attributes + self.buffer = None + self.dataset = None + + self.policy = None + self.q1 = None + self.q2 = None + self.target_q1 = None + self.target_q2 = None + self.build_networks() + + self.agent = SoftActorCriticAgent(self.policy) + + # Hyperparameters + self.sync_rate = sync_rate + self.gamma = gamma + self.batch_size = batch_size + self.replay_size = replay_size + self.warm_start_size = warm_start_size + self.batches_per_epoch = batches_per_epoch + self.n_steps = n_steps + + self.save_hyperparameters() + + # Metrics + self.total_episode_steps = [0] + self.total_rewards = [0] + self.done_episodes = 0 + self.total_steps = 0 + + # Average Rewards + self.avg_reward_len = avg_reward_len + + for _ in 
range(avg_reward_len): + self.total_rewards.append(torch.tensor(min_episode_reward, device=self.device)) + + self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:])) + + self.state = self.env.reset() + + self.automatic_optimization = False + + def run_n_episodes(self, env, n_epsiodes: int = 1) -> List[int]: + """ + Carries out N episodes of the environment with the current agent without exploration + + Args: + env: environment to use, either train environment or test environment + n_epsiodes: number of episodes to run + """ + total_rewards = [] + + for _ in range(n_epsiodes): + episode_state = env.reset() + done = False + episode_reward = 0 + + while not done: + action = self.agent.get_action(episode_state, self.device) + next_state, reward, done, _ = env.step(action[0]) + episode_state = next_state + episode_reward += reward + + total_rewards.append(episode_reward) + + return total_rewards + + def populate(self, warm_start: int) -> None: + """Populates the buffer with initial experience""" + if warm_start > 0: + self.state = self.env.reset() + + for _ in range(warm_start): + action = self.agent(self.state, self.device) + next_state, reward, done, _ = self.env.step(action[0]) + exp = Experience( + state=self.state, + action=action[0], + reward=reward, + done=done, + new_state=next_state + ) + self.buffer.append(exp) + self.state = next_state + + if done: + self.state = self.env.reset() + print("done populating") + + def build_networks(self) -> None: + """Initializes the SAC policy and q networks (with targets)""" + action_bias = torch.from_numpy((self.env.action_space.high + self.env.action_space.low) / 2) + action_scale = torch.from_numpy((self.env.action_space.high - self.env.action_space.low) / 2) + self.policy = ContinuousMLP( + self.obs_shape, + self.n_actions, + action_bias=action_bias, + action_scale=action_scale + ) + + concat_shape = [self.obs_shape[0] + self.n_actions] + self.q1 = MLP(concat_shape, 1) + self.q2 = MLP(concat_shape, 1) + self.target_q1 = MLP(concat_shape, 1) + self.target_q2 = MLP(concat_shape, 1) + self.target_q1.load_state_dict(self.q1.state_dict()) + self.target_q2.load_state_dict(self.q2.state_dict()) + + def soft_update_target(self, q_net, target_net): + """ + Update the weights in target network using a weighted sum + w_target := (1-a) * w_target + a * w_q + + Args: + q_net: the critic (q) network + target_net: the target (q) network + """ + for q_param, target_param in zip(q_net.parameters(), target_net.parameters()): + target_param.data.copy_( + (1.0 - self.hparams.target_alpha) * target_param.data + + self.hparams.target_alpha * q_param + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Passes in a state x through the network and gets the q_values of each action as an output + + Args: + x: environment state + + Returns: + q values + """ + output = self.policy(x).sample() + return output + + def train_batch(self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Contains the logic for generating a new batch of data to be passed to the DataLoader + + Returns: + yields a Experience tuple containing the state, action, reward, done and next_state. 
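The `soft_update_target` helper above is plain Polyak averaging: each call moves the target weights a step of size `target_alpha` (5e-3 by default) toward the online critic, so the influence of the old target weights halves after roughly ln(0.5)/ln(1 - 5e-3) ≈ 138 updates. Unlike a hard copy every N steps, this nudges the target every training step, which is why `sync_rate` defaults to 1 here. A standalone sketch of the same rule on two toy modules — the layer sizes are invented for illustration; only the update line mirrors the loop in `soft_update_target`::

    import torch
    from torch import nn

    alpha = 5e-3                     # mirrors the target_alpha hyperparameter
    q_net = nn.Linear(4, 1)
    target_net = nn.Linear(4, 1)
    target_net.load_state_dict(q_net.state_dict())  # target starts as an exact copy

    # w_target := (1 - alpha) * w_target + alpha * w_q
    for q_param, target_param in zip(q_net.parameters(), target_net.parameters()):
        target_param.data.copy_((1.0 - alpha) * target_param.data + alpha * q_param.data)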
+ """ + episode_reward = 0 + episode_steps = 0 + + while True: + self.total_steps += 1 + action = self.agent(self.state, self.device) + + next_state, r, is_done, _ = self.env.step(action[0]) + + episode_reward += r + episode_steps += 1 + + exp = Experience(state=self.state, action=action[0], reward=r, done=is_done, new_state=next_state) + + self.buffer.append(exp) + self.state = next_state + + if is_done: + self.done_episodes += 1 + self.total_rewards.append(episode_reward) + self.total_episode_steps.append(episode_steps) + self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:])) + self.state = self.env.reset() + episode_steps = 0 + episode_reward = 0 + + states, actions, rewards, dones, new_states = self.buffer.sample(self.batch_size) + + for idx, _ in enumerate(dones): + yield states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx] + + # Simulates epochs + if self.total_steps % self.batches_per_epoch == 0: + break + + def loss( + self, + batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Calculates the loss for SAC which contains a total of 3 losses + + Args: + batch: a batch of states, actions, rewards, dones, and next states + """ + states, actions, rewards, dones, next_states = batch + rewards = rewards.unsqueeze(-1) + dones = dones.float().unsqueeze(-1) + + # actor + dist = self.policy(states) + new_actions, new_logprobs = dist.rsample_and_log_prob() + new_logprobs = new_logprobs.unsqueeze(-1) + + new_states_actions = torch.cat((states, new_actions), 1) + new_q1_values = self.q1(new_states_actions) + new_q2_values = self.q2(new_states_actions) + new_qmin_values = torch.min(new_q1_values, new_q2_values) + + policy_loss = (new_logprobs - new_qmin_values).mean() + + # critic + states_actions = torch.cat((states, actions), 1) + q1_values = self.q1(states_actions) + q2_values = self.q2(states_actions) + + with torch.no_grad(): + next_dist = self.policy(next_states) + new_next_actions, new_next_logprobs = next_dist.rsample_and_log_prob() + new_next_logprobs = new_next_logprobs.unsqueeze(-1) + + new_next_states_actions = torch.cat((next_states, new_next_actions), 1) + next_q1_values = self.target_q1(new_next_states_actions) + next_q2_values = self.target_q2(new_next_states_actions) + next_qmin_values = torch.min(next_q1_values, next_q2_values) - new_next_logprobs + target_values = rewards + (1. - dones) * self.gamma * next_qmin_values + + q1_loss = F.mse_loss(q1_values, target_values) + q2_loss = F.mse_loss(q2_values, target_values) + + return policy_loss, q1_loss, q2_loss + + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _, optimizer_idx): + """ + Carries out a single step through the environment to update the replay buffer. 
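The critic target built in `loss` above is the soft Bellman backup: take the smaller of the two target-Q estimates for the sampled next action, subtract the entropy term log π(a'|s'), discount, and mask out terminal transitions. A toy numerical illustration — every value below is invented for the example and none of it comes from the patch::

    import torch

    gamma = 0.99
    rewards = torch.tensor([[1.0], [0.5]])
    dones = torch.tensor([[0.0], [1.0]])           # second transition ends the episode
    next_q1 = torch.tensor([[10.0], [8.0]])        # target_q1(s', a')
    next_q2 = torch.tensor([[9.0], [11.0]])        # target_q2(s', a')
    next_logprob = torch.tensor([[-1.2], [-0.7]])  # log pi(a' | s')

    next_qmin = torch.min(next_q1, next_q2) - next_logprob
    target = rewards + (1.0 - dones) * gamma * next_qmin
    # target[0] = 1.0 + 0.99 * (9.0 + 1.2) = 11.098
    # target[1] = 0.5 because done masks out the bootstrap term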
+ Then calculates loss based on the minibatch recieved + + Args: + batch: current mini batch of replay data + _: batch number, not used + optimizer_idx: not used + """ + policy_optim, q1_optim, q2_optim = self.optimizers() + policy_loss, q1_loss, q2_loss = self.loss(batch) + + policy_optim.zero_grad() + self.manual_backward(policy_loss) + policy_optim.step() + + q1_optim.zero_grad() + self.manual_backward(q1_loss) + q1_optim.step() + + q2_optim.zero_grad() + self.manual_backward(q2_loss) + q2_optim.step() + + # Soft update of target network + if self.global_step % self.sync_rate == 0: + self.soft_update_target(self.q1, self.target_q1) + self.soft_update_target(self.q2, self.target_q2) + + self.log_dict({ + "total_reward": self.total_rewards[-1], + "avg_reward": self.avg_rewards, + "policy_loss": policy_loss, + "q1_loss": q1_loss, + "q2_loss": q2_loss, + "episodes": self.done_episodes, + "episode_steps": self.total_episode_steps[-1] + }) + + def test_step(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + """Evaluate the agent for 10 episodes""" + test_reward = self.run_n_episodes(self.test_env, 1) + avg_reward = sum(test_reward) / len(test_reward) + return {"test_reward": avg_reward} + + def test_epoch_end(self, outputs) -> Dict[str, torch.Tensor]: + """Log the avg of the test results""" + rewards = [x["test_reward"] for x in outputs] + avg_reward = sum(rewards) / len(rewards) + self.log("avg_test_reward", avg_reward) + return {"avg_test_reward": avg_reward} + + def _dataloader(self) -> DataLoader: + """Initialize the Replay Buffer dataset used for retrieving experiences""" + self.buffer = MultiStepBuffer(self.replay_size, self.n_steps) + self.populate(self.warm_start_size) + + self.dataset = ExperienceSourceDataset(self.train_batch) + return DataLoader(dataset=self.dataset, batch_size=self.batch_size) + + def train_dataloader(self) -> DataLoader: + """Get train loader""" + return self._dataloader() + + def test_dataloader(self) -> DataLoader: + """Get test loader""" + return self._dataloader() + + def configure_optimizers(self) -> Tuple[Optimizer]: + """ Initialize Adam optimizer""" + policy_optim = optim.Adam(self.policy.parameters(), self.hparams.policy_learning_rate) + q1_optim = optim.Adam(self.q1.parameters(), self.hparams.q_learning_rate) + q2_optim = optim.Adam(self.q2.parameters(), self.hparams.q_learning_rate) + return policy_optim, q1_optim, q2_optim + + @staticmethod + def add_model_specific_args(arg_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: + """ + Adds arguments for DQN model + + Note: + These params are fine tuned for Pong env. 
+ + Args: + arg_parser: parent parser + """ + arg_parser.add_argument( + "--sync_rate", + type=int, + default=1, + help="how many frames do we update the target network", + ) + arg_parser.add_argument( + "--replay_size", + type=int, + default=100000, + help="capacity of the replay buffer", + ) + arg_parser.add_argument( + "--warm_start_size", + type=int, + default=10000, + help="how many samples do we use to fill our buffer at the start of training", + ) + arg_parser.add_argument("--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch") + arg_parser.add_argument("--batch_size", type=int, default=128, help="size of the batches") + arg_parser.add_argument("--policy_lr", type=float, default=3e-4, help="policy learning rate") + arg_parser.add_argument("--q_lr", type=float, default=3e-4, help="q learning rate") + arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag") + arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + + arg_parser.add_argument( + "--avg_reward_len", + type=int, + default=100, + help="how many episodes to include in avg reward", + ) + arg_parser.add_argument( + "--n_steps", + type=int, + default=1, + help="how many frames do we update the target network", + ) + + return arg_parser + + +def cli_main(): + parser = argparse.ArgumentParser(add_help=False) + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = SAC.add_model_specific_args(parser) + args = parser.parse_args() + + model = SAC(**args.__dict__) + + # save checkpoints based on avg_reward + checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True) + + seed_everything(123) + trainer = pl.Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback) + + trainer.fit(model) + + +if __name__ == '__main__': + cli_main() From 8c2145febb0a2a9d315ca7bba3bfa01b6728efc2 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Wed, 28 Apr 2021 17:02:40 -0700 Subject: [PATCH 02/20] added tests --- .../integration/test_actor_critic_models.py | 36 +++++++++++ tests/models/rl/test_scripts.py | 15 +++++ tests/models/rl/unit/test_sac.py | 59 +++++++++++++++++++ 3 files changed, 110 insertions(+) create mode 100644 tests/models/rl/integration/test_actor_critic_models.py create mode 100644 tests/models/rl/unit/test_sac.py diff --git a/tests/models/rl/integration/test_actor_critic_models.py b/tests/models/rl/integration/test_actor_critic_models.py new file mode 100644 index 0000000000..f3a88465c6 --- /dev/null +++ b/tests/models/rl/integration/test_actor_critic_models.py @@ -0,0 +1,36 @@ +import argparse + +import pytorch_lightning as pl + +from pl_bolts.models.rl.sac_model import SAC + + +def test_sac(): + """Smoke test that the SAC model runs""" + + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = pl.Trainer.add_argparse_args(parent_parser) + parent_parser = SAC.add_model_specific_args(parent_parser) + args_list = [ + "--warm_start_size", + "100", + "--gpus", + "0", + "--env", + "Pendulum-v0", + "--batch_size", + "10", + ] + hparams = parent_parser.parse_args(args_list) + + trainer = pl.Trainer( + gpus=hparams.gpus, + max_steps=100, + max_epochs=100, # Set this as the same as max steps to ensure that it doesn't stop early + val_check_interval=1, # This just needs 'some' value, does not effect training right now + fast_dev_run=True + ) + model = SAC(**hparams.__dict__) + result = trainer.fit(model) + + assert 
result == 1 diff --git a/tests/models/rl/test_scripts.py b/tests/models/rl/test_scripts.py index ee30206718..f70d862e19 100644 --- a/tests/models/rl/test_scripts.py +++ b/tests/models/rl/test_scripts.py @@ -126,3 +126,18 @@ def test_cli_run_rl_vanilla_policy_gradient(cli_args): cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() + + +@pytest.mark.parametrize('cli_args', [ + ' --env Pendulum-v0' + ' --max_steps 10' + ' --fast_dev_run 1' + ' --batch_size 10', +]) +def test_cli_run_rl_soft_actor_critic(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.rl.sac_model import cli_main + + cli_args = cli_args.strip().split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() diff --git a/tests/models/rl/unit/test_sac.py b/tests/models/rl/unit/test_sac.py new file mode 100644 index 0000000000..2c3f86aae7 --- /dev/null +++ b/tests/models/rl/unit/test_sac.py @@ -0,0 +1,59 @@ +import argparse + +import torch + +from pl_bolts.models.rl.sac_model import SAC + + +def test_sac_loss(): + """Test the reinforce loss function""" + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = SAC.add_model_specific_args(parent_parser) + args_list = [ + "--env", + "Pendulum-v0", + "--batch_size", + "32", + ] + hparams = parent_parser.parse_args(args_list) + model = SAC(**vars(hparams)) + + batch_states = torch.rand(32, 3) + batch_actions = torch.rand(32, 1) + batch_rewards = torch.rand(32) + batch_dones = torch.ones(32) + batch_next_states = torch.rand(32, 3) + batch = (batch_states, batch_actions, batch_rewards, batch_dones, batch_next_states) + + policy_loss, q1_loss, q2_loss = model.loss(batch) + + assert isinstance(policy_loss, torch.Tensor) + assert isinstance(q1_loss, torch.Tensor) + assert isinstance(q2_loss, torch.Tensor) + + +def test_sac_train_batch(): + """Tests that a single batch generates correctly""" + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = SAC.add_model_specific_args(parent_parser) + args_list = [ + "--env", + "Pendulum-v0", + "--batch_size", + "32", + ] + hparams = parent_parser.parse_args(args_list) + model = SAC(**vars(hparams)) + + xp_dataloader = model.train_dataloader() + + batch = next(iter(xp_dataloader)) + + assert len(batch) == 5 + assert len(batch[0]) == model.hparams.batch_size + assert isinstance(batch, list) + assert isinstance(batch[0], torch.Tensor) + assert isinstance(batch[1], torch.Tensor) + assert isinstance(batch[2], torch.Tensor) + assert isinstance(batch[3], torch.Tensor) + assert isinstance(batch[4], torch.Tensor) From 0c872a12e98fe9cc9a109154d622e2a729f4b600 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Fri, 30 Apr 2021 20:49:41 -0700 Subject: [PATCH 03/20] finish document and init --- .../rl_benchmark/pendulum_sac_results.jpg | Bin 0 -> 56862 bytes docs/source/reinforce_learn.rst | 86 ++++++++++++++++++ pl_bolts/models/rl/__init__.py | 2 + pl_bolts/models/rl/sac_model.py | 11 +-- 4 files changed, 93 insertions(+), 6 deletions(-) create mode 100644 docs/source/_images/rl_benchmark/pendulum_sac_results.jpg diff --git a/docs/source/_images/rl_benchmark/pendulum_sac_results.jpg b/docs/source/_images/rl_benchmark/pendulum_sac_results.jpg new file mode 100644 index 0000000000000000000000000000000000000000..77e91fab9832f5d09ef33990d3ee517a9d0b9043 GIT binary patch literal 56862 zcmdSB2|SeT_b`6XjD6pCjY768g^*=P2uT!EgzRMB*Nl`UOG=BPC?!fv%2LSKMMd_a 
[binary JPEG data for docs/source/_images/rl_benchmark/pendulum_sac_results.jpg truncated]
zMlUxfsuGybnUt#pVB&nb1aV|?0b<@sA?Xf-xM|!=zJ=Cb);1T>mniNwTqmBm!fxzE z_3T$Yzif1carIWIS=&A^J*bbdxihl($Y@>j!b8KZ%~x+EG)C5Jcq%rSCr)0GsKYez zIrzR%Y1*Zo>(y#@axuC#VQ#vd->n+v;~6h{Xm>wlyLo7JBdMi~xjyE}mNhPgULFtM zo*DvS^#&>}F^C&p)j(NMV;q09y;M_&&c3!{DbML)G*?`EcJ+5)cci#S3x4lg=PV+&*&}Zy2=ht3}Mq)i@!(o zSDs(us`9co_Ik?uvLe+wkF78b&RjRTypj+H%k$T{wj=p`pGoHM2%b6J6h%OF%Z|(Q zt4byZ`GzCKBg}wlII{zU3)lMZ+|!Cjzu?v2@+UTO_K@lEic?Bouo6{n8n-1xL1 z_FemPt)2ui^}GdgC;+*aqhC zzf+?A^d{_!Da#`m5^D~DKkCR}$5Xl}e_}g}*(%ORboQ|a@}dkiF7rdr{D$BM zdJkoLn=#fk?tHUX4F8U2{Bw?(`DL>qW9qaQCngJXnU$QiP{X*de*W)d99AndzU@sU zBkpQ5of$q+TK^U*g7`&1+Qb4PfRWS*5(NwIv5G}zz7PC!})EXu(3F!#sZl{#i ze;>Lo5;%`J--m*-)0Sc$7>Q02+iM`HA6em#kxvJBW&pUI6G>STYaq>kY#GI?1=EBh z4ZxMtHQu=V!K!Q#u@I;Tgb-rAKobNBB|0WKnvc?^`mTsN+58*?_e)> zG&;LUV*XGKJz;3g3Qu88t;22~)L3-NisRi2?T*a7sF!Wg_4{MvR)>+179FljQWw{g zAC)?^^{`zz+3j>8e~)8GMvCtKV%tUaj&YgM_aX`lH9cmk{|jL&jvc_(Kr8X*1(wq_ zl)D|6?n8#z!?G;ihr}I--F+>*@St%`))Qr{vRN=jomf-CJ`XK!BF^LjW%i};| zDXVvpn`OQc;6zPgyWNsp<98J0066p@sQV&qN$wF2nWVAU{aw5wv}_Hu;0RpYL+Q}p zKF@=+^|M=-=7m_S^=zdybOXBx>OJwR%9=@3rJEE3*-eBaR314hbX^E6ntW@ z470yQ>?ZzrBmqUjZiWCJ$*cjJCAs$Lq)R+K9M7aK;uGrj-~XhR>P)!_jO~T z4T03bH^sR-K!~+Q#oKD|dQQRN2#7%KQxP+a4NUTB@aR z*{efQ>r++=JwA*~FL7;ziWdfEupGcs;7C5J2DVyv=9V&{bo z3`w2UG90lJ_Tn!#%ou zke5RIZ_|4OKmZqs1N-pKUdC2ID`_ +Paper authors: Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, Sergey Levine + +Original implementation by: `Jason Wang `_ + +Soft Actor Critic (SAC) is a powerful actor critic algorithm in reinforcement learning. Unlike A2C, SAC's policy outputs a +special continuous distribution for actions, and its critic estimates the Q value instead of the state value, which +means it now takes in not only states but also actions. The new actor allows SAC to support continuous action tasks such +as controlling robots, and the new critic allows SAC to support off-policy learning which is more sample efficient. + +The actor has a new objective to maximize entropy to encourage exploration while maximizing the expected rewards. +The critic uses two separate Q functions to "mitigate positive bias" during training by picking the minimum of the +two as the predicted Q value. + +Since SAC is off-policy, its algorithm's training step is quite similar to DQN: + +1. Initialize one policy network, two Q networks, and two corresponding target Q networks. +2. Run 1 step using action sampled from policy and store the transition into the replay buffer. + +.. math:: + a \sim tanh(N(\mu_\pi(s), \sigma_\pi(s))) + +3. Sample transitions (states, actions, rewards, dones, next states) from the replay buffer. + +.. math:: + s, a, r, d, s' \sim B + +4. Compute actor loss and update policy network. + +.. math:: + J_\pi = \frac1n\sum_i(\log\pi(\pi(a | s_i) | s_i) - Q_{min}(s_i, \pi(a | s_i))) + +5. Compute Q target + +.. math:: + target_i = r_i + (1 - d_i) \gamma (\min_i Q_{target,i}(s'_i, \pi(a', s'_i)) - log\pi(\pi(a | s'_i) | s'_i)) + +5. Compute critic loss and update Q network.. + +.. math:: + J_{Q_i} = \frac1n \sum_i(Q_i(s_i, a_i) - target_i)^2 + +4. Soft update the target Q network using a weighted sum of itself and the Q network. + +.. math:: + Q_{target,i} := \tau Q_{target,i} + (1-\tau) Q_i + +SAC Benefits +~~~~~~~~~~~~~~~~~~~ + +- More sample efficient due to off-policy training + +- Supports continuous action space + +SAC Results +~~~~~~~~~~~~~~~~ + +.. 
image:: _images/rl_benchmark/pendulum_sac_results.jpg + :width: 300 + :alt: SAC Results + +Example:: + from pl_bolts.models.rl import SAC + sac = SAC("Pendulum-v0") + trainer = Trainer() + trainer.fit(sac) + +.. autoclass:: pl_bolts.models.rl.SAC +:noindex: diff --git a/pl_bolts/models/rl/__init__.py b/pl_bolts/models/rl/__init__.py index 070ec666be..73b82dbde5 100644 --- a/pl_bolts/models/rl/__init__.py +++ b/pl_bolts/models/rl/__init__.py @@ -4,6 +4,7 @@ from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401 from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401 from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401 +from pl_bolts.models.rl.sac_model import SAC from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401 __all__ = [ @@ -13,5 +14,6 @@ "NoisyDQN", "PERDQN", "Reinforce", + "SAC" "VanillaPolicyGradient", ] diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index ef5689d238..635f86ed80 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -43,13 +43,13 @@ def __init__( policy_learning_rate: float = 3e-4, q_learning_rate: float = 3e-4, target_alpha: float = 5e-3, - batch_size: int = 256, - replay_size: int = 100000, - warm_start_size: int = 1000000, + batch_size: int = 128, + replay_size: int = 1000000, + warm_start_size: int = 10000, avg_reward_len: int = 100, min_episode_reward: int = -21, seed: int = 123, - batches_per_epoch: int = 1000, + batches_per_epoch: int = 10000, n_steps: int = 1, **kwargs, ): @@ -149,7 +149,6 @@ def populate(self, warm_start: int) -> None: if done: self.state = self.env.reset() - print("done populating") def build_networks(self) -> None: """Initializes the SAC policy and q networks (with targets)""" @@ -383,7 +382,7 @@ def add_model_specific_args(arg_parser: argparse.ArgumentParser, ) -> argparse.A arg_parser.add_argument( "--replay_size", type=int, - default=100000, + default=1000000, help="capacity of the replay buffer", ) arg_parser.add_argument( From 742943ee1a4e87091e3b08102fcdca5f471e54e7 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Fri, 30 Apr 2021 20:59:44 -0700 Subject: [PATCH 04/20] fix style 1 --- pl_bolts/models/rl/common/agents.py | 5 +++-- pl_bolts/models/rl/common/distributions.py | 6 +++--- pl_bolts/models/rl/common/networks.py | 8 ++++---- pl_bolts/models/rl/sac_model.py | 20 ++++++++++---------- tests/models/rl/unit/test_sac.py | 2 +- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index 92478ba1b4..817705ac9f 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -139,6 +139,7 @@ def __call__(self, states: torch.Tensor, device: str) -> List[int]: return actions + class SoftActorCriticAgent(Agent): """Actor-Critic based agent that returns a continuous action based on the policy""" def __call__(self, states: torch.Tensor, device: str) -> List[float]: @@ -160,7 +161,7 @@ def __call__(self, states: torch.Tensor, device: str) -> List[float]: dist = self.net(states) actions = [a for a in dist.sample().cpu().numpy()] - + return actions def get_action(self, states: torch.Tensor, device: str) -> List[float]: @@ -181,5 +182,5 @@ def get_action(self, states: torch.Tensor, device: str) -> List[float]: states = torch.tensor(states, device=device) actions = [self.net.get_action(states).cpu().numpy()] - + return actions diff --git a/pl_bolts/models/rl/common/distributions.py 
b/pl_bolts/models/rl/common/distributions.py index 2230b84c5d..33482a8566 100644 --- a/pl_bolts/models/rl/common/distributions.py +++ b/pl_bolts/models/rl/common/distributions.py @@ -47,7 +47,7 @@ def rsample_and_log_prob(self, sample_shape=torch.Size()): """ Samples X and computes the log probability of the sample - Returns: + Returns: Sampled X and log probability """ z = super().rsample() @@ -55,7 +55,7 @@ def rsample_and_log_prob(self, sample_shape=torch.Size()): value = torch.tanh(z) correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1) return self.action_scale * value + self.action_bias, z_logprob - correction - + """ Some override methods """ @@ -65,5 +65,5 @@ def rsample(self, sample_shape=torch.Size()): def log_prob(self, value): value = (value - self.action_bias) / self.action_scale - z = torch.log(1 + value) / 2 - torch.log(1-value) / 2 + z = torch.log(1 + value) / 2 - torch.log(1 - value) / 2 return self.log_prob_with_z(value, z) diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index 74d2439dc8..c0f9ab1ec1 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -99,10 +99,10 @@ class ContinuousMLP(nn.Module): MLP network that outputs continuous value via Gaussian distribution """ def __init__( - self, - input_shape: Tuple[int], - n_actions: int, - hidden_size: int = 128, + self, + input_shape: Tuple[int], + n_actions: int, + hidden_size: int = 128, action_bias: int = 0, action_scale: int = 1 ): diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index 635f86ed80..039b221dbe 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -138,10 +138,10 @@ def populate(self, warm_start: int) -> None: action = self.agent(self.state, self.device) next_state, reward, done, _ = self.env.step(action[0]) exp = Experience( - state=self.state, - action=action[0], - reward=reward, - done=done, + state=self.state, + action=action[0], + reward=reward, + done=done, new_state=next_state ) self.buffer.append(exp) @@ -155,9 +155,9 @@ def build_networks(self) -> None: action_bias = torch.from_numpy((self.env.action_space.high + self.env.action_space.low) / 2) action_scale = torch.from_numpy((self.env.action_space.high - self.env.action_space.low) / 2) self.policy = ContinuousMLP( - self.obs_shape, - self.n_actions, - action_bias=action_bias, + self.obs_shape, + self.n_actions, + action_bias=action_bias, action_scale=action_scale ) @@ -180,8 +180,8 @@ def soft_update_target(self, q_net, target_net): """ for q_param, target_param in zip(q_net.parameters(), target_net.parameters()): target_param.data.copy_( - (1.0 - self.hparams.target_alpha) * target_param.data + - self.hparams.target_alpha * q_param + (1.0 - self.hparams.target_alpha) * target_param.data + + self.hparams.target_alpha * q_param ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -240,7 +240,7 @@ def train_batch(self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch break def loss( - self, + self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ diff --git a/tests/models/rl/unit/test_sac.py b/tests/models/rl/unit/test_sac.py index 2c3f86aae7..980064fb55 100644 --- a/tests/models/rl/unit/test_sac.py +++ b/tests/models/rl/unit/test_sac.py @@ -17,7 +17,7 @@ def test_sac_loss(): ] hparams = parent_parser.parse_args(args_list) model = SAC(**vars(hparams)) - + batch_states = 
torch.rand(32, 3) batch_actions = torch.rand(32, 1) batch_rewards = torch.rand(32) From 700cdbb41c8bd7acb610bd206e10833af36b7e0c Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Fri, 30 Apr 2021 21:01:21 -0700 Subject: [PATCH 05/20] fix style 2 --- pl_bolts/models/rl/sac_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index 039b221dbe..ffe12433e2 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -180,8 +180,8 @@ def soft_update_target(self, q_net, target_net): """ for q_param, target_param in zip(q_net.parameters(), target_net.parameters()): target_param.data.copy_( - (1.0 - self.hparams.target_alpha) * target_param.data - + self.hparams.target_alpha * q_param + (1.0 - self.hparams.target_alpha) * target_param.data + + self.hparams.target_alpha * q_param ) def forward(self, x: torch.Tensor) -> torch.Tensor: From 08ce087218ac30e983b5a291b7b24e7fce446bf7 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Thu, 6 May 2021 20:26:47 -0700 Subject: [PATCH 06/20] fix style 3 --- pl_bolts/models/rl/__init__.py | 2 +- pl_bolts/models/rl/sac_model.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pl_bolts/models/rl/__init__.py b/pl_bolts/models/rl/__init__.py index 73b82dbde5..16298f10a6 100644 --- a/pl_bolts/models/rl/__init__.py +++ b/pl_bolts/models/rl/__init__.py @@ -4,7 +4,7 @@ from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401 from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401 from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401 -from pl_bolts.models.rl.sac_model import SAC +from pl_bolts.models.rl.sac_model import SAC # noqa: F401 from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401 __all__ = [ diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index ffe12433e2..f12f291405 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -2,8 +2,7 @@ Soft Actor Critic """ import argparse -from collections import OrderedDict -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import numpy as np import pytorch_lightning as pl @@ -16,9 +15,7 @@ from torch.utils.data import DataLoader from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset -from pl_bolts.losses.rl import dqn_loss from pl_bolts.models.rl.common.agents import SoftActorCriticAgent -from pl_bolts.models.rl.common.gym_wrappers import make_environment from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import ContinuousMLP, MLP from pl_bolts.utils import _GYM_AVAILABLE From a54490148d5497321b094b7dc9137aa1b8bc5449 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Jun 2021 07:29:23 +0000 Subject: [PATCH 07/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/reinforce_learn.rst | 12 +++++----- pl_bolts/models/rl/common/agents.py | 1 + pl_bolts/models/rl/common/distributions.py | 6 +++-- pl_bolts/models/rl/common/networks.py | 11 +++------ pl_bolts/models/rl/sac_model.py | 27 ++++++---------------- 5 files changed, 21 insertions(+), 36 deletions(-) diff --git a/docs/source/reinforce_learn.rst b/docs/source/reinforce_learn.rst index 9ab85ac7e9..792b73ab4c 100644 --- a/docs/source/reinforce_learn.rst +++ 
b/docs/source/reinforce_learn.rst @@ -688,13 +688,13 @@ Paper authors: Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, Sergey Levine Original implementation by: `Jason Wang `_ -Soft Actor Critic (SAC) is a powerful actor critic algorithm in reinforcement learning. Unlike A2C, SAC's policy outputs a -special continuous distribution for actions, and its critic estimates the Q value instead of the state value, which -means it now takes in not only states but also actions. The new actor allows SAC to support continuous action tasks such -as controlling robots, and the new critic allows SAC to support off-policy learning which is more sample efficient. +Soft Actor Critic (SAC) is a powerful actor critic algorithm in reinforcement learning. Unlike A2C, SAC's policy outputs a +special continuous distribution for actions, and its critic estimates the Q value instead of the state value, which +means it now takes in not only states but also actions. The new actor allows SAC to support continuous action tasks such +as controlling robots, and the new critic allows SAC to support off-policy learning which is more sample efficient. -The actor has a new objective to maximize entropy to encourage exploration while maximizing the expected rewards. -The critic uses two separate Q functions to "mitigate positive bias" during training by picking the minimum of the +The actor has a new objective to maximize entropy to encourage exploration while maximizing the expected rewards. +The critic uses two separate Q functions to "mitigate positive bias" during training by picking the minimum of the two as the predicted Q value. Since SAC is off-policy, its algorithm's training step is quite similar to DQN: diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index 15df03f128..ae220e24bd 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -142,6 +142,7 @@ def __call__(self, states: Tensor, device: str) -> List[int]: class SoftActorCriticAgent(Agent): """Actor-Critic based agent that returns a continuous action based on the policy""" + def __call__(self, states: torch.Tensor, device: str) -> List[float]: """ Takes in the current state and returns the action based on the agents policy diff --git a/pl_bolts/models/rl/common/distributions.py b/pl_bolts/models/rl/common/distributions.py index 33482a8566..9eaeda806c 100644 --- a/pl_bolts/models/rl/common/distributions.py +++ b/pl_bolts/models/rl/common/distributions.py @@ -10,6 +10,7 @@ class TanhMultivariateNormal(torch.distributions.MultivariateNormal): X = action_scale * tanh(Z) + action_bias Z ~ Normal(mean, variance) """ + def __init__(self, action_bias, action_scale, **kwargs): super().__init__(**kwargs) @@ -40,7 +41,7 @@ def log_prob_with_z(self, value, z): """ value = (value - self.action_bias) / self.action_scale z_logprob = super().log_prob(z) - correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1) + correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) return z_logprob - correction def rsample_and_log_prob(self, sample_shape=torch.Size()): @@ -53,12 +54,13 @@ def rsample_and_log_prob(self, sample_shape=torch.Size()): z = super().rsample() z_logprob = super().log_prob(z) value = torch.tanh(z) - correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1) + correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) return self.action_scale * value + self.action_bias, z_logprob - correction """ Some override methods """ + def 
rsample(self, sample_shape=torch.Size()): fz, z = self.rsample_with_z(sample_shape) return fz diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index c0f9ab1ec1..20704f6fce 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -98,6 +98,7 @@ class ContinuousMLP(nn.Module): """ MLP network that outputs continuous value via Gaussian distribution """ + def __init__( self, input_shape: Tuple[int], @@ -119,10 +120,7 @@ def __init__( self.action_scale = action_scale self.shared_net = nn.Sequential( - nn.Linear(input_shape[0], hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, hidden_size), - nn.ReLU() + nn.Linear(input_shape[0], hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU() ) self.mean_layer = nn.Linear(hidden_size, n_actions) self.logstd_layer = nn.Linear(hidden_size, n_actions) @@ -141,10 +139,7 @@ def forward(self, x: torch.FloatTensor) -> TanhMultivariateNormal: logstd = torch.clamp(self.logstd_layer(x), -20, 2) batch_scale_tril = torch.diag_embed(torch.exp(logstd)) return TanhMultivariateNormal( - action_bias=self.action_bias, - action_scale=self.action_scale, - loc=batch_mean, - scale_tril=batch_scale_tril + action_bias=self.action_bias, action_scale=self.action_scale, loc=batch_mean, scale_tril=batch_scale_tril ) def get_action(self, x: torch.FloatTensor) -> torch.Tensor: diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index f12f291405..45fff1b6ac 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -10,8 +10,8 @@ from pytorch_lightning import seed_everything from pytorch_lightning.callbacks import ModelCheckpoint from torch import optim as optim -from torch.optim.optimizer import Optimizer from torch.nn import functional as F +from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset @@ -29,6 +29,7 @@ class SAC(pl.LightningModule): + def __init__( self, env: str, @@ -134,13 +135,7 @@ def populate(self, warm_start: int) -> None: for _ in range(warm_start): action = self.agent(self.state, self.device) next_state, reward, done, _ = self.env.step(action[0]) - exp = Experience( - state=self.state, - action=action[0], - reward=reward, - done=done, - new_state=next_state - ) + exp = Experience(state=self.state, action=action[0], reward=reward, done=done, new_state=next_state) self.buffer.append(exp) self.state = next_state @@ -151,12 +146,7 @@ def build_networks(self) -> None: """Initializes the SAC policy and q networks (with targets)""" action_bias = torch.from_numpy((self.env.action_space.high + self.env.action_space.low) / 2) action_scale = torch.from_numpy((self.env.action_space.high - self.env.action_space.low) / 2) - self.policy = ContinuousMLP( - self.obs_shape, - self.n_actions, - action_bias=action_bias, - action_scale=action_scale - ) + self.policy = ContinuousMLP(self.obs_shape, self.n_actions, action_bias=action_bias, action_scale=action_scale) concat_shape = [self.obs_shape[0] + self.n_actions] self.q1 = MLP(concat_shape, 1) @@ -176,10 +166,8 @@ def soft_update_target(self, q_net, target_net): target_net: the target (q) network """ for q_param, target_param in zip(q_net.parameters(), target_net.parameters()): - target_param.data.copy_( - (1.0 - self.hparams.target_alpha) * target_param.data + - self.hparams.target_alpha * q_param - ) + target_param.data.copy_((1.0 - self.hparams.target_alpha) * 
target_param.data + + self.hparams.target_alpha * q_param) def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -237,8 +225,7 @@ def train_batch(self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch break def loss( - self, - batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Calculates the loss for SAC which contains a total of 3 losses From 71e0decdab93c18043d3a1ff7c3df378fb4a147f Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 24 Jun 2021 09:31:36 +0200 Subject: [PATCH 08/20] formt --- docs/source/reinforce_learn.rst | 12 ++--- pl_bolts/models/rl/common/agents.py | 9 ++-- pl_bolts/models/rl/common/distributions.py | 6 ++- pl_bolts/models/rl/common/networks.py | 17 +++---- pl_bolts/models/rl/sac_model.py | 49 +++++++------------ .../integration/test_actor_critic_models.py | 6 +-- tests/models/rl/unit/test_sac.py | 17 ++++--- 7 files changed, 50 insertions(+), 66 deletions(-) diff --git a/docs/source/reinforce_learn.rst b/docs/source/reinforce_learn.rst index 9ab85ac7e9..792b73ab4c 100644 --- a/docs/source/reinforce_learn.rst +++ b/docs/source/reinforce_learn.rst @@ -688,13 +688,13 @@ Paper authors: Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, Sergey Levine Original implementation by: `Jason Wang `_ -Soft Actor Critic (SAC) is a powerful actor critic algorithm in reinforcement learning. Unlike A2C, SAC's policy outputs a -special continuous distribution for actions, and its critic estimates the Q value instead of the state value, which -means it now takes in not only states but also actions. The new actor allows SAC to support continuous action tasks such -as controlling robots, and the new critic allows SAC to support off-policy learning which is more sample efficient. +Soft Actor Critic (SAC) is a powerful actor critic algorithm in reinforcement learning. Unlike A2C, SAC's policy outputs a +special continuous distribution for actions, and its critic estimates the Q value instead of the state value, which +means it now takes in not only states but also actions. The new actor allows SAC to support continuous action tasks such +as controlling robots, and the new critic allows SAC to support off-policy learning which is more sample efficient. -The actor has a new objective to maximize entropy to encourage exploration while maximizing the expected rewards. -The critic uses two separate Q functions to "mitigate positive bias" during training by picking the minimum of the +The actor has a new objective to maximize entropy to encourage exploration while maximizing the expected rewards. +The critic uses two separate Q functions to "mitigate positive bias" during training by picking the minimum of the two as the predicted Q value. 
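The entropy term and the twin-Q minimum described above combine into one policy loss and two critic losses. A minimal sketch of that computation, mirroring the structure of the ``SAC.loss`` method added in this series (the tensor shapes, the ``.squeeze`` calls, and the callables passed in are assumptions of the sketch, not the exact module API)::

    import torch
    from torch.nn import functional as F

    def sac_losses(policy, q1, q2, target_q1, target_q2, batch, gamma=0.99):
        # batch of 1-step transitions; rewards/dones assumed to be float tensors of shape (batch,)
        states, actions, rewards, dones, next_states = batch

        # actor loss: entropy-regularized, uses the minimum of the two critics
        dist = policy(states)
        new_actions, new_logprobs = dist.rsample_and_log_prob()
        new_sa = torch.cat((states, new_actions), dim=1)
        q_min = torch.min(q1(new_sa), q2(new_sa)).squeeze(-1)  # critics assumed to output (batch, 1)
        policy_loss = (new_logprobs - q_min).mean()

        # critic target: soft value of the next state from the *target* networks
        with torch.no_grad():
            next_dist = policy(next_states)
            next_actions, next_logprobs = next_dist.rsample_and_log_prob()
            next_sa = torch.cat((next_states, next_actions), dim=1)
            next_q = torch.min(target_q1(next_sa), target_q2(next_sa)).squeeze(-1) - next_logprobs
            target = rewards + (1.0 - dones) * gamma * next_q

        sa = torch.cat((states, actions), dim=1)
        q1_loss = F.mse_loss(q1(sa).squeeze(-1), target)
        q2_loss = F.mse_loss(q2(sa).squeeze(-1), target)
        return policy_loss, q1_loss, q2_loss

The ``SAC.loss`` method in this patch reads the discount factor and the networks from ``self``; the sketch only isolates the arithmetic behind the three losses.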
Since SAC is off-policy, its algorithm's training step is quite similar to DQN: diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index 15df03f128..b0c808f93a 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -142,7 +142,8 @@ def __call__(self, states: Tensor, device: str) -> List[int]: class SoftActorCriticAgent(Agent): """Actor-Critic based agent that returns a continuous action based on the policy""" - def __call__(self, states: torch.Tensor, device: str) -> List[float]: + + def __call__(self, states: Tensor, device: str) -> List[float]: """ Takes in the current state and returns the action based on the agents policy @@ -156,7 +157,7 @@ def __call__(self, states: torch.Tensor, device: str) -> List[float]: if not isinstance(states, list): states = [states] - if not isinstance(states, torch.Tensor): + if not isinstance(states, Tensor): states = torch.tensor(states, device=device) dist = self.net(states) @@ -164,7 +165,7 @@ def __call__(self, states: torch.Tensor, device: str) -> List[float]: return actions - def get_action(self, states: torch.Tensor, device: str) -> List[float]: + def get_action(self, states: Tensor, device: str) -> List[float]: """ Get the action greedily (without sampling) @@ -178,7 +179,7 @@ def get_action(self, states: torch.Tensor, device: str) -> List[float]: if not isinstance(states, list): states = [states] - if not isinstance(states, torch.Tensor): + if not isinstance(states, Tensor): states = torch.tensor(states, device=device) actions = [self.net.get_action(states).cpu().numpy()] diff --git a/pl_bolts/models/rl/common/distributions.py b/pl_bolts/models/rl/common/distributions.py index 33482a8566..9eaeda806c 100644 --- a/pl_bolts/models/rl/common/distributions.py +++ b/pl_bolts/models/rl/common/distributions.py @@ -10,6 +10,7 @@ class TanhMultivariateNormal(torch.distributions.MultivariateNormal): X = action_scale * tanh(Z) + action_bias Z ~ Normal(mean, variance) """ + def __init__(self, action_bias, action_scale, **kwargs): super().__init__(**kwargs) @@ -40,7 +41,7 @@ def log_prob_with_z(self, value, z): """ value = (value - self.action_bias) / self.action_scale z_logprob = super().log_prob(z) - correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1) + correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) return z_logprob - correction def rsample_and_log_prob(self, sample_shape=torch.Size()): @@ -53,12 +54,13 @@ def rsample_and_log_prob(self, sample_shape=torch.Size()): z = super().rsample() z_logprob = super().log_prob(z) value = torch.tanh(z) - correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1) + correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) return self.action_scale * value + self.action_bias, z_logprob - correction """ Some override methods """ + def rsample(self, sample_shape=torch.Size()): fz, z = self.rsample_with_z(sample_shape) return fz diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index c0f9ab1ec1..6a2a173d5f 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -6,7 +6,7 @@ import numpy as np import torch -from torch import nn, Tensor +from torch import FloatTensor, nn, Tensor from torch.nn import functional as F from pl_bolts.models.rl.common.distributions import TanhMultivariateNormal @@ -98,6 +98,7 @@ class ContinuousMLP(nn.Module): """ MLP network that outputs continuous value via Gaussian 
distribution """ + def __init__( self, input_shape: Tuple[int], @@ -119,15 +120,12 @@ def __init__( self.action_scale = action_scale self.shared_net = nn.Sequential( - nn.Linear(input_shape[0], hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, hidden_size), - nn.ReLU() + nn.Linear(input_shape[0], hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU() ) self.mean_layer = nn.Linear(hidden_size, n_actions) self.logstd_layer = nn.Linear(hidden_size, n_actions) - def forward(self, x: torch.FloatTensor) -> TanhMultivariateNormal: + def forward(self, x: FloatTensor) -> TanhMultivariateNormal: """ Forward pass through network. Calculates the action distribution @@ -141,13 +139,10 @@ def forward(self, x: torch.FloatTensor) -> TanhMultivariateNormal: logstd = torch.clamp(self.logstd_layer(x), -20, 2) batch_scale_tril = torch.diag_embed(torch.exp(logstd)) return TanhMultivariateNormal( - action_bias=self.action_bias, - action_scale=self.action_scale, - loc=batch_mean, - scale_tril=batch_scale_tril + action_bias=self.action_bias, action_scale=self.action_scale, loc=batch_mean, scale_tril=batch_scale_tril ) - def get_action(self, x: torch.FloatTensor) -> torch.Tensor: + def get_action(self, x: FloatTensor) -> Tensor: """ Get the action greedily (without sampling) diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index f12f291405..033bf5cbb5 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -5,13 +5,13 @@ from typing import Dict, List, Tuple import numpy as np -import pytorch_lightning as pl import torch -from pytorch_lightning import seed_everything +from pytorch_lightning import LightningModule, seed_everything, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from torch import optim as optim -from torch.optim.optimizer import Optimizer +from torch import Tensor from torch.nn import functional as F +from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset @@ -28,7 +28,8 @@ Env = object -class SAC(pl.LightningModule): +class SAC(LightningModule): + def __init__( self, env: str, @@ -134,13 +135,7 @@ def populate(self, warm_start: int) -> None: for _ in range(warm_start): action = self.agent(self.state, self.device) next_state, reward, done, _ = self.env.step(action[0]) - exp = Experience( - state=self.state, - action=action[0], - reward=reward, - done=done, - new_state=next_state - ) + exp = Experience(state=self.state, action=action[0], reward=reward, done=done, new_state=next_state) self.buffer.append(exp) self.state = next_state @@ -151,12 +146,7 @@ def build_networks(self) -> None: """Initializes the SAC policy and q networks (with targets)""" action_bias = torch.from_numpy((self.env.action_space.high + self.env.action_space.low) / 2) action_scale = torch.from_numpy((self.env.action_space.high - self.env.action_space.low) / 2) - self.policy = ContinuousMLP( - self.obs_shape, - self.n_actions, - action_bias=action_bias, - action_scale=action_scale - ) + self.policy = ContinuousMLP(self.obs_shape, self.n_actions, action_bias=action_bias, action_scale=action_scale) concat_shape = [self.obs_shape[0] + self.n_actions] self.q1 = MLP(concat_shape, 1) @@ -176,12 +166,10 @@ def soft_update_target(self, q_net, target_net): target_net: the target (q) network """ for q_param, target_param in zip(q_net.parameters(), target_net.parameters()): - target_param.data.copy_( - (1.0 - self.hparams.target_alpha) 
* target_param.data + - self.hparams.target_alpha * q_param - ) + target_param.data.copy_((1.0 - self.hparams.target_alpha) * target_param.data + + self.hparams.target_alpha * q_param) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: """ Passes in a state x through the network and gets the q_values of each action as an output @@ -194,7 +182,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = self.policy(x).sample() return output - def train_batch(self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def train_batch(self, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """ Contains the logic for generating a new batch of data to be passed to the DataLoader @@ -236,10 +224,7 @@ def train_batch(self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch if self.total_steps % self.batches_per_epoch == 0: break - def loss( - self, - batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def loss(self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]: """ Calculates the loss for SAC which contains a total of 3 losses @@ -283,7 +268,7 @@ def loss( return policy_loss, q1_loss, q2_loss - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _, optimizer_idx): + def training_step(self, batch: Tuple[Tensor, Tensor], _, optimizer_idx): """ Carries out a single step through the environment to update the replay buffer. Then calculates loss based on the minibatch recieved @@ -323,13 +308,13 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _, optimizer_i "episode_steps": self.total_episode_steps[-1] }) - def test_step(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + def test_step(self, *args, **kwargs) -> Dict[str, Tensor]: """Evaluate the agent for 10 episodes""" test_reward = self.run_n_episodes(self.test_env, 1) avg_reward = sum(test_reward) / len(test_reward) return {"test_reward": avg_reward} - def test_epoch_end(self, outputs) -> Dict[str, torch.Tensor]: + def test_epoch_end(self, outputs) -> Dict[str, Tensor]: """Log the avg of the test results""" rewards = [x["test_reward"] for x in outputs] avg_reward = sum(rewards) / len(rewards) @@ -415,7 +400,7 @@ def cli_main(): parser = argparse.ArgumentParser(add_help=False) # trainer args - parser = pl.Trainer.add_argparse_args(parser) + parser = Trainer.add_argparse_args(parser) # model args parser = SAC.add_model_specific_args(parser) @@ -427,7 +412,7 @@ def cli_main(): checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True) seed_everything(123) - trainer = pl.Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback) + trainer = Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback) trainer.fit(model) diff --git a/tests/models/rl/integration/test_actor_critic_models.py b/tests/models/rl/integration/test_actor_critic_models.py index f3a88465c6..cb5d3ade95 100644 --- a/tests/models/rl/integration/test_actor_critic_models.py +++ b/tests/models/rl/integration/test_actor_critic_models.py @@ -1,6 +1,6 @@ import argparse -import pytorch_lightning as pl +from pytorch_lightning import Trainer from pl_bolts.models.rl.sac_model import SAC @@ -9,7 +9,7 @@ def test_sac(): """Smoke test that the SAC model runs""" parent_parser = argparse.ArgumentParser(add_help=False) - 
parent_parser = pl.Trainer.add_argparse_args(parent_parser) + parent_parser = Trainer.add_argparse_args(parent_parser) parent_parser = SAC.add_model_specific_args(parent_parser) args_list = [ "--warm_start_size", @@ -23,7 +23,7 @@ def test_sac(): ] hparams = parent_parser.parse_args(args_list) - trainer = pl.Trainer( + trainer = Trainer( gpus=hparams.gpus, max_steps=100, max_epochs=100, # Set this as the same as max steps to ensure that it doesn't stop early diff --git a/tests/models/rl/unit/test_sac.py b/tests/models/rl/unit/test_sac.py index 980064fb55..1e8e0565d2 100644 --- a/tests/models/rl/unit/test_sac.py +++ b/tests/models/rl/unit/test_sac.py @@ -1,6 +1,7 @@ import argparse import torch +from torch import Tensor from pl_bolts.models.rl.sac_model import SAC @@ -27,9 +28,9 @@ def test_sac_loss(): policy_loss, q1_loss, q2_loss = model.loss(batch) - assert isinstance(policy_loss, torch.Tensor) - assert isinstance(q1_loss, torch.Tensor) - assert isinstance(q2_loss, torch.Tensor) + assert isinstance(policy_loss, Tensor) + assert isinstance(q1_loss, Tensor) + assert isinstance(q2_loss, Tensor) def test_sac_train_batch(): @@ -52,8 +53,8 @@ def test_sac_train_batch(): assert len(batch) == 5 assert len(batch[0]) == model.hparams.batch_size assert isinstance(batch, list) - assert isinstance(batch[0], torch.Tensor) - assert isinstance(batch[1], torch.Tensor) - assert isinstance(batch[2], torch.Tensor) - assert isinstance(batch[3], torch.Tensor) - assert isinstance(batch[4], torch.Tensor) + assert isinstance(batch[0], Tensor) + assert isinstance(batch[1], Tensor) + assert isinstance(batch[2], Tensor) + assert isinstance(batch[3], Tensor) + assert isinstance(batch[4], Tensor) From 557ea57f5f3482eacdd2a369269a67012ab41157 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 24 Jun 2021 09:36:53 +0200 Subject: [PATCH 09/20] Apply suggestions from code review --- pl_bolts/models/rl/common/distributions.py | 3 --- tests/models/rl/unit/test_sac.py | 6 +----- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/pl_bolts/models/rl/common/distributions.py b/pl_bolts/models/rl/common/distributions.py index 9eaeda806c..e29bc4d81f 100644 --- a/pl_bolts/models/rl/common/distributions.py +++ b/pl_bolts/models/rl/common/distributions.py @@ -57,9 +57,6 @@ def rsample_and_log_prob(self, sample_shape=torch.Size()): correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) return self.action_scale * value + self.action_bias, z_logprob - correction - """ - Some override methods - """ def rsample(self, sample_shape=torch.Size()): fz, z = self.rsample_with_z(sample_shape) diff --git a/tests/models/rl/unit/test_sac.py b/tests/models/rl/unit/test_sac.py index 1e8e0565d2..9a19f36e43 100644 --- a/tests/models/rl/unit/test_sac.py +++ b/tests/models/rl/unit/test_sac.py @@ -53,8 +53,4 @@ def test_sac_train_batch(): assert len(batch) == 5 assert len(batch[0]) == model.hparams.batch_size assert isinstance(batch, list) - assert isinstance(batch[0], Tensor) - assert isinstance(batch[1], Tensor) - assert isinstance(batch[2], Tensor) - assert isinstance(batch[3], Tensor) - assert isinstance(batch[4], Tensor) + assert all(isinstance(batch[i], Tensor) for i in range(5)) From ad47e3455aa59e65f5434d91b82ae454c44d852c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Jun 2021 07:37:19 +0000 Subject: [PATCH 10/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
pl_bolts/models/rl/common/distributions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pl_bolts/models/rl/common/distributions.py b/pl_bolts/models/rl/common/distributions.py index e29bc4d81f..de3aaf93a3 100644 --- a/pl_bolts/models/rl/common/distributions.py +++ b/pl_bolts/models/rl/common/distributions.py @@ -57,7 +57,6 @@ def rsample_and_log_prob(self, sample_shape=torch.Size()): correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) return self.action_scale * value + self.action_bias, z_logprob - correction - def rsample(self, sample_shape=torch.Size()): fz, z = self.rsample_with_z(sample_shape) return fz From d81e8e0a31d68fde9ef2a8a71d735458a97368fb Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Tue, 6 Jul 2021 18:23:16 -0700 Subject: [PATCH 11/20] use hyperparameters in hparams --- pl_bolts/models/rl/sac_model.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index f12f291405..5ff01a753b 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -73,14 +73,6 @@ def __init__( self.agent = SoftActorCriticAgent(self.policy) # Hyperparameters - self.sync_rate = sync_rate - self.gamma = gamma - self.batch_size = batch_size - self.replay_size = replay_size - self.warm_start_size = warm_start_size - self.batches_per_epoch = batches_per_epoch - self.n_steps = n_steps - self.save_hyperparameters() # Metrics @@ -227,13 +219,13 @@ def train_batch(self, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch episode_steps = 0 episode_reward = 0 - states, actions, rewards, dones, new_states = self.buffer.sample(self.batch_size) + states, actions, rewards, dones, new_states = self.buffer.sample(self.hparams.batch_size) for idx, _ in enumerate(dones): yield states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx] # Simulates epochs - if self.total_steps % self.batches_per_epoch == 0: + if self.total_steps % self.hparams.batches_per_epoch == 0: break def loss( @@ -276,7 +268,7 @@ def loss( next_q1_values = self.target_q1(new_next_states_actions) next_q2_values = self.target_q2(new_next_states_actions) next_qmin_values = torch.min(next_q1_values, next_q2_values) - new_next_logprobs - target_values = rewards + (1. - dones) * self.gamma * next_qmin_values + target_values = rewards + (1. 
- dones) * self.hparams.gamma * next_qmin_values q1_loss = F.mse_loss(q1_values, target_values) q2_loss = F.mse_loss(q2_values, target_values) @@ -309,7 +301,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _, optimizer_i q2_optim.step() # Soft update of target network - if self.global_step % self.sync_rate == 0: + if self.global_step % self.hparams.sync_rate == 0: self.soft_update_target(self.q1, self.target_q1) self.soft_update_target(self.q2, self.target_q2) @@ -338,11 +330,11 @@ def test_epoch_end(self, outputs) -> Dict[str, torch.Tensor]: def _dataloader(self) -> DataLoader: """Initialize the Replay Buffer dataset used for retrieving experiences""" - self.buffer = MultiStepBuffer(self.replay_size, self.n_steps) - self.populate(self.warm_start_size) + self.buffer = MultiStepBuffer(self.hparams.replay_size, self.hparams.n_steps) + self.populate(self.hparams.warm_start_size) self.dataset = ExperienceSourceDataset(self.train_batch) - return DataLoader(dataset=self.dataset, batch_size=self.batch_size) + return DataLoader(dataset=self.dataset, batch_size=self.hparams.batch_size) def train_dataloader(self) -> DataLoader: """Get train loader""" From d101d50b7b400f9e5b4db83b29a09abd7e66451e Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Tue, 6 Jul 2021 18:31:20 -0700 Subject: [PATCH 12/20] Add CHANGELOG --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a247597da4..ad2d28e359 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added + - Added Soft Actor Critic (SAC) Model [#627](https://github.com/PyTorchLightning/lightning-bolts/pull/627)) + ### Changed From 43daba334a81177945e2df38bfc7f4d58ec517c0 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Tue, 20 Jul 2021 09:53:28 -0700 Subject: [PATCH 13/20] fix test --- tests/models/rl/integration/test_actor_critic_models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/models/rl/integration/test_actor_critic_models.py b/tests/models/rl/integration/test_actor_critic_models.py index cb5d3ade95..73d314de2f 100644 --- a/tests/models/rl/integration/test_actor_critic_models.py +++ b/tests/models/rl/integration/test_actor_critic_models.py @@ -32,5 +32,3 @@ def test_sac(): ) model = SAC(**hparams.__dict__) result = trainer.fit(model) - - assert result == 1 From 25763337a8c576c5afa39ce0c8a4e70fc2da381c Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Sat, 31 Jul 2021 20:50:48 -0700 Subject: [PATCH 14/20] fix format --- tests/models/rl/integration/test_actor_critic_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/rl/integration/test_actor_critic_models.py b/tests/models/rl/integration/test_actor_critic_models.py index 73d314de2f..a46b05a23b 100644 --- a/tests/models/rl/integration/test_actor_critic_models.py +++ b/tests/models/rl/integration/test_actor_critic_models.py @@ -31,4 +31,4 @@ def test_sac(): fast_dev_run=True ) model = SAC(**hparams.__dict__) - result = trainer.fit(model) + trainer.fit(model) From 73a13d1176f78953b0def4dc0bf638c35e271b8e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Aug 2021 17:46:58 +0000 Subject: [PATCH 15/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_bolts/models/rl/__init__.py | 3 +- pl_bolts/models/rl/common/agents.py | 14 ++- pl_bolts/models/rl/common/distributions.py | 25 
++--- pl_bolts/models/rl/common/networks.py | 19 ++-- pl_bolts/models/rl/sac_model.py | 92 +++++++++---------- .../integration/test_actor_critic_models.py | 8 +- tests/models/rl/test_scripts.py | 14 +-- tests/models/rl/unit/test_sac.py | 4 +- 8 files changed, 83 insertions(+), 96 deletions(-) diff --git a/pl_bolts/models/rl/__init__.py b/pl_bolts/models/rl/__init__.py index bee9cb496e..e4c9244e03 100644 --- a/pl_bolts/models/rl/__init__.py +++ b/pl_bolts/models/rl/__init__.py @@ -16,6 +16,5 @@ "NoisyDQN", "PERDQN", "Reinforce", - "SAC" - "VanillaPolicyGradient", + "SAC" "VanillaPolicyGradient", ] diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index c211758801..cbefa5d635 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -164,12 +164,11 @@ def __call__(self, states: Tensor, device: str) -> List[int]: class SoftActorCriticAgent(Agent): - """Actor-Critic based agent that returns a continuous action based on the policy""" + """Actor-Critic based agent that returns a continuous action based on the policy.""" def __call__(self, states: Tensor, device: str) -> List[float]: - """ - Takes in the current state and returns the action based on the agents policy - + """Takes in the current state and returns the action based on the agents policy. + Args: states: current state of the environment device: the device used for the current batch @@ -189,8 +188,7 @@ def __call__(self, states: Tensor, device: str) -> List[float]: return actions def get_action(self, states: Tensor, device: str) -> List[float]: - """ - Get the action greedily (without sampling) + """Get the action greedily (without sampling) Args: states: current state of the environment @@ -206,5 +204,5 @@ def get_action(self, states: Tensor, device: str) -> List[float]: states = torch.tensor(states, device=device) actions = [self.net.get_action(states).cpu().numpy()] - - return actions \ No newline at end of file + + return actions diff --git a/pl_bolts/models/rl/common/distributions.py b/pl_bolts/models/rl/common/distributions.py index de3aaf93a3..0374928c24 100644 --- a/pl_bolts/models/rl/common/distributions.py +++ b/pl_bolts/models/rl/common/distributions.py @@ -1,14 +1,12 @@ -""" -Distributions used in some continuous RL algorithms -""" +"""Distributions used in some continuous RL algorithms.""" import torch class TanhMultivariateNormal(torch.distributions.MultivariateNormal): - """ - The distribution of X is an affine of tanh applied on a normal distribution - X = action_scale * tanh(Z) + action_bias - Z ~ Normal(mean, variance) + """The distribution of X is an affine of tanh applied on a normal distribution. + + X = action_scale * tanh(Z) + action_bias + Z ~ Normal(mean, variance) """ def __init__(self, action_bias, action_scale, **kwargs): @@ -18,8 +16,7 @@ def __init__(self, action_bias, action_scale, **kwargs): self.action_scale = action_scale def rsample_with_z(self, sample_shape=torch.Size()): - """ - Samples X using reparametrization trick with the intermediate variable Z + """Samples X using reparametrization trick with the intermediate variable Z. Returns: Sampled X and Z @@ -28,8 +25,7 @@ def rsample_with_z(self, sample_shape=torch.Size()): return self.action_scale * torch.tanh(z) + self.action_bias, z def log_prob_with_z(self, value, z): - """ - Computes the log probability of a sampled X + """Computes the log probability of a sampled X. 
Refer to the original paper of SAC for more details in equation (20), (21) @@ -41,12 +37,11 @@ def log_prob_with_z(self, value, z): """ value = (value - self.action_bias) / self.action_scale z_logprob = super().log_prob(z) - correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) + correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1) return z_logprob - correction def rsample_and_log_prob(self, sample_shape=torch.Size()): - """ - Samples X and computes the log probability of the sample + """Samples X and computes the log probability of the sample. Returns: Sampled X and log probability @@ -54,7 +49,7 @@ def rsample_and_log_prob(self, sample_shape=torch.Size()): z = super().rsample() z_logprob = super().log_prob(z) value = torch.tanh(z) - correction = torch.log(self.action_scale * (1 - value**2) + 1e-7).sum(1) + correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1) return self.action_scale * value + self.action_bias, z_logprob - correction def rsample(self, sample_shape=torch.Size()): diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index 065e894ced..476aae54ea 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -4,7 +4,7 @@ import numpy as np import torch -from torch import FloatTensor, nn, Tensor +from torch import FloatTensor, Tensor, nn from torch.distributions import Categorical, Normal from torch.nn import functional as F @@ -87,9 +87,7 @@ def forward(self, input_x): class ContinuousMLP(nn.Module): - """ - MLP network that outputs continuous value via Gaussian distribution - """ + """MLP network that outputs continuous value via Gaussian distribution.""" def __init__( self, @@ -97,7 +95,7 @@ def __init__( n_actions: int, hidden_size: int = 128, action_bias: int = 0, - action_scale: int = 1 + action_scale: int = 1, ): """ Args: @@ -107,7 +105,7 @@ def __init__( action_bias: the center of the action space action_scale: the scale of the action space """ - super(ContinuousMLP, self).__init__() + super().__init__() self.action_bias = action_bias self.action_scale = action_scale @@ -118,14 +116,13 @@ def __init__( self.logstd_layer = nn.Linear(hidden_size, n_actions) def forward(self, x: FloatTensor) -> TanhMultivariateNormal: - """ - Forward pass through network. Calculates the action distribution + """Forward pass through network. Calculates the action distribution. 
Args: x: input to network Returns: action distribution - """ + """ x = self.shared_net(x.float()) batch_mean = self.mean_layer(x) logstd = torch.clamp(self.logstd_layer(x), -20, 2) @@ -135,8 +132,7 @@ def forward(self, x: FloatTensor) -> TanhMultivariateNormal: ) def get_action(self, x: FloatTensor) -> Tensor: - """ - Get the action greedily (without sampling) + """Get the action greedily (without sampling) Args: x: input to network @@ -147,6 +143,7 @@ def get_action(self, x: FloatTensor) -> Tensor: batch_mean = self.mean_layer(x) return self.action_scale * torch.tanh(batch_mean) + self.action_bias + class ActorCriticMLP(nn.Module): """MLP network with heads for actor and critic.""" diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index 8901dbc9b2..441f8ca9d7 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -1,15 +1,13 @@ -""" -Soft Actor Critic -""" +"""Soft Actor Critic.""" import argparse from typing import Dict, List, Tuple import numpy as np import torch -from pytorch_lightning import LightningModule, seed_everything, Trainer +from pytorch_lightning import LightningModule, Trainer, seed_everything from pytorch_lightning.callbacks import ModelCheckpoint -from torch import optim as optim from torch import Tensor +from torch import optim as optim from torch.nn import functional as F from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader @@ -17,19 +15,18 @@ from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset from pl_bolts.models.rl.common.agents import SoftActorCriticAgent from pl_bolts.models.rl.common.memory import MultiStepBuffer -from pl_bolts.models.rl.common.networks import ContinuousMLP, MLP +from pl_bolts.models.rl.common.networks import MLP, ContinuousMLP from pl_bolts.utils import _GYM_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg if _GYM_AVAILABLE: import gym else: # pragma: no cover - warn_missing_pkg('gym') + warn_missing_pkg("gym") Env = object class SAC(LightningModule): - def __init__( self, env: str, @@ -88,15 +85,14 @@ def __init__( for _ in range(avg_reward_len): self.total_rewards.append(torch.tensor(min_episode_reward, device=self.device)) - self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:])) + self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len :])) self.state = self.env.reset() self.automatic_optimization = False def run_n_episodes(self, env, n_epsiodes: int = 1) -> List[int]: - """ - Carries out N episodes of the environment with the current agent without exploration + """Carries out N episodes of the environment with the current agent without exploration. Args: env: environment to use, either train environment or test environment @@ -120,7 +116,7 @@ def run_n_episodes(self, env, n_epsiodes: int = 1) -> List[int]: return total_rewards def populate(self, warm_start: int) -> None: - """Populates the buffer with initial experience""" + """Populates the buffer with initial experience.""" if warm_start > 0: self.state = self.env.reset() @@ -149,8 +145,8 @@ def build_networks(self) -> None: self.target_q2.load_state_dict(self.q2.state_dict()) def soft_update_target(self, q_net, target_net): - """ - Update the weights in target network using a weighted sum + """Update the weights in target network using a weighted sum. 
+ w_target := (1-a) * w_target + a * w_q Args: @@ -158,12 +154,12 @@ def soft_update_target(self, q_net, target_net): target_net: the target (q) network """ for q_param, target_param in zip(q_net.parameters(), target_net.parameters()): - target_param.data.copy_((1.0 - self.hparams.target_alpha) * target_param.data - + self.hparams.target_alpha * q_param) + target_param.data.copy_( + (1.0 - self.hparams.target_alpha) * target_param.data + self.hparams.target_alpha * q_param + ) def forward(self, x: Tensor) -> Tensor: - """ - Passes in a state x through the network and gets the q_values of each action as an output + """Passes in a state x through the network and gets the q_values of each action as an output. Args: x: environment state @@ -174,9 +170,10 @@ def forward(self, x: Tensor) -> Tensor: output = self.policy(x).sample() return output - def train_batch(self, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: - """ - Contains the logic for generating a new batch of data to be passed to the DataLoader + def train_batch( + self, + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + """Contains the logic for generating a new batch of data to be passed to the DataLoader. Returns: yields a Experience tuple containing the state, action, reward, done and next_state. @@ -202,7 +199,7 @@ def train_batch(self, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: self.done_episodes += 1 self.total_rewards.append(episode_reward) self.total_episode_steps.append(episode_steps) - self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:])) + self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len :])) self.state = self.env.reset() episode_steps = 0 episode_reward = 0 @@ -217,8 +214,7 @@ def train_batch(self, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: break def loss(self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]: - """ - Calculates the loss for SAC which contains a total of 3 losses + """Calculates the loss for SAC which contains a total of 3 losses. Args: batch: a batch of states, actions, rewards, dones, and next states @@ -253,7 +249,7 @@ def loss(self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tuple[Te next_q1_values = self.target_q1(new_next_states_actions) next_q2_values = self.target_q2(new_next_states_actions) next_qmin_values = torch.min(next_q1_values, next_q2_values) - new_next_logprobs - target_values = rewards + (1. - dones) * self.hparams.gamma * next_qmin_values + target_values = rewards + (1.0 - dones) * self.hparams.gamma * next_qmin_values q1_loss = F.mse_loss(q1_values, target_values) q2_loss = F.mse_loss(q2_values, target_values) @@ -261,9 +257,8 @@ def loss(self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tuple[Te return policy_loss, q1_loss, q2_loss def training_step(self, batch: Tuple[Tensor, Tensor], _, optimizer_idx): - """ - Carries out a single step through the environment to update the replay buffer. - Then calculates loss based on the minibatch recieved + """Carries out a single step through the environment to update the replay buffer. Then calculates loss + based on the minibatch recieved. 
Args: batch: current mini batch of replay data @@ -290,31 +285,33 @@ def training_step(self, batch: Tuple[Tensor, Tensor], _, optimizer_idx): self.soft_update_target(self.q1, self.target_q1) self.soft_update_target(self.q2, self.target_q2) - self.log_dict({ - "total_reward": self.total_rewards[-1], - "avg_reward": self.avg_rewards, - "policy_loss": policy_loss, - "q1_loss": q1_loss, - "q2_loss": q2_loss, - "episodes": self.done_episodes, - "episode_steps": self.total_episode_steps[-1] - }) + self.log_dict( + { + "total_reward": self.total_rewards[-1], + "avg_reward": self.avg_rewards, + "policy_loss": policy_loss, + "q1_loss": q1_loss, + "q2_loss": q2_loss, + "episodes": self.done_episodes, + "episode_steps": self.total_episode_steps[-1], + } + ) def test_step(self, *args, **kwargs) -> Dict[str, Tensor]: - """Evaluate the agent for 10 episodes""" + """Evaluate the agent for 10 episodes.""" test_reward = self.run_n_episodes(self.test_env, 1) avg_reward = sum(test_reward) / len(test_reward) return {"test_reward": avg_reward} def test_epoch_end(self, outputs) -> Dict[str, Tensor]: - """Log the avg of the test results""" + """Log the avg of the test results.""" rewards = [x["test_reward"] for x in outputs] avg_reward = sum(rewards) / len(rewards) self.log("avg_test_reward", avg_reward) return {"avg_test_reward": avg_reward} def _dataloader(self) -> DataLoader: - """Initialize the Replay Buffer dataset used for retrieving experiences""" + """Initialize the Replay Buffer dataset used for retrieving experiences.""" self.buffer = MultiStepBuffer(self.hparams.replay_size, self.hparams.n_steps) self.populate(self.hparams.warm_start_size) @@ -322,24 +319,25 @@ def _dataloader(self) -> DataLoader: return DataLoader(dataset=self.dataset, batch_size=self.hparams.batch_size) def train_dataloader(self) -> DataLoader: - """Get train loader""" + """Get train loader.""" return self._dataloader() def test_dataloader(self) -> DataLoader: - """Get test loader""" + """Get test loader.""" return self._dataloader() def configure_optimizers(self) -> Tuple[Optimizer]: - """ Initialize Adam optimizer""" + """Initialize Adam optimizer.""" policy_optim = optim.Adam(self.policy.parameters(), self.hparams.policy_learning_rate) q1_optim = optim.Adam(self.q1.parameters(), self.hparams.q_learning_rate) q2_optim = optim.Adam(self.q2.parameters(), self.hparams.q_learning_rate) return policy_optim, q1_optim, q2_optim @staticmethod - def add_model_specific_args(arg_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: - """ - Adds arguments for DQN model + def add_model_specific_args( + arg_parser: argparse.ArgumentParser, + ) -> argparse.ArgumentParser: + """Adds arguments for DQN model. Note: These params are fine tuned for Pong env. 
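``add_model_specific_args`` above only registers the SAC hyperparameters on a parser; a short sketch of how they are consumed, following the same flow as ``cli_main`` and the tests in this series (the CLI values themselves are illustrative)::

    import argparse

    from pytorch_lightning import Trainer

    from pl_bolts.models.rl.sac_model import SAC

    parser = argparse.ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)    # trainer flags (--max_steps, --gpus, ...)
    parser = SAC.add_model_specific_args(parser)  # SAC flags (--env, --batch_size, ...)

    # illustrative values, kept small so the run finishes quickly
    args = parser.parse_args(["--env", "Pendulum-v0", "--warm_start_size", "100", "--batch_size", "32"])

    model = SAC(**vars(args))
    trainer = Trainer.from_argparse_args(args, fast_dev_run=True)
    trainer.fit(model)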
@@ -409,5 +407,5 @@ def cli_main(): trainer.fit(model) -if __name__ == '__main__': +if __name__ == "__main__": cli_main() diff --git a/tests/models/rl/integration/test_actor_critic_models.py b/tests/models/rl/integration/test_actor_critic_models.py index 09f82ef762..cb93f02b19 100644 --- a/tests/models/rl/integration/test_actor_critic_models.py +++ b/tests/models/rl/integration/test_actor_critic_models.py @@ -18,7 +18,6 @@ def test_a2c(): hparams = parent_parser.parse_args(args_list) trainer = Trainer( - gpus=0, max_steps=100, max_epochs=100, # Set this as the same as max steps to ensure that it doesn't stop early @@ -28,8 +27,9 @@ def test_a2c(): model = AdvantageActorCritic(hparams.env) trainer.fit(model) + def test_sac(): - """Smoke test that the SAC model runs""" + """Smoke test that the SAC model runs.""" parent_parser = argparse.ArgumentParser(add_help=False) parent_parser = Trainer.add_argparse_args(parent_parser) @@ -51,7 +51,7 @@ def test_sac(): max_steps=100, max_epochs=100, # Set this as the same as max steps to ensure that it doesn't stop early val_check_interval=1, # This just needs 'some' value, does not effect training right now - fast_dev_run=True + fast_dev_run=True, ) model = SAC(**hparams.__dict__) - trainer.fit(model) \ No newline at end of file + trainer.fit(model) diff --git a/tests/models/rl/test_scripts.py b/tests/models/rl/test_scripts.py index a438c492e6..119f3c4285 100644 --- a/tests/models/rl/test_scripts.py +++ b/tests/models/rl/test_scripts.py @@ -148,16 +148,16 @@ def test_cli_run_rl_advantage_actor_critic(cli_args): cli_main() -@pytest.mark.parametrize('cli_args', [ - ' --env Pendulum-v0' - ' --max_steps 10' - ' --fast_dev_run 1' - ' --batch_size 10', -]) +@pytest.mark.parametrize( + "cli_args", + [ + " --env Pendulum-v0" " --max_steps 10" " --fast_dev_run 1" " --batch_size 10", + ], +) def test_cli_run_rl_soft_actor_critic(cli_args): """Test running CLI for an example with default params.""" from pl_bolts.models.rl.sac_model import cli_main - cli_args = cli_args.strip().split(' ') if cli_args else [] + cli_args = cli_args.strip().split(" ") if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() diff --git a/tests/models/rl/unit/test_sac.py b/tests/models/rl/unit/test_sac.py index 9a19f36e43..c668bd3b4f 100644 --- a/tests/models/rl/unit/test_sac.py +++ b/tests/models/rl/unit/test_sac.py @@ -7,7 +7,7 @@ def test_sac_loss(): - """Test the reinforce loss function""" + """Test the reinforce loss function.""" parent_parser = argparse.ArgumentParser(add_help=False) parent_parser = SAC.add_model_specific_args(parent_parser) args_list = [ @@ -34,7 +34,7 @@ def test_sac_loss(): def test_sac_train_batch(): - """Tests that a single batch generates correctly""" + """Tests that a single batch generates correctly.""" parent_parser = argparse.ArgumentParser(add_help=False) parent_parser = SAC.add_model_specific_args(parent_parser) args_list = [ From be19c64133b98c4d0172495e91e005d398d32faf Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Fri, 13 Aug 2021 10:49:16 -0700 Subject: [PATCH 16/20] fix __init__ --- pl_bolts/models/rl/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pl_bolts/models/rl/__init__.py b/pl_bolts/models/rl/__init__.py index e4c9244e03..5fb0e78a3d 100644 --- a/pl_bolts/models/rl/__init__.py +++ b/pl_bolts/models/rl/__init__.py @@ -16,5 +16,6 @@ "NoisyDQN", "PERDQN", "Reinforce", - "SAC" "VanillaPolicyGradient", + "SAC", + "VanillaPolicyGradient", ] From 
25aa7e085f5b948aafdc57a8786e95dce46193ca Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Aug 2021 18:03:32 +0000 Subject: [PATCH 17/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pl_bolts/models/rl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/models/rl/__init__.py b/pl_bolts/models/rl/__init__.py index 5fb0e78a3d..36f6594566 100644 --- a/pl_bolts/models/rl/__init__.py +++ b/pl_bolts/models/rl/__init__.py @@ -5,7 +5,7 @@ from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN from pl_bolts.models.rl.per_dqn_model import PERDQN from pl_bolts.models.rl.reinforce_model import Reinforce -from pl_bolts.models.rl.sac_model import SAC # noqa: F401 +from pl_bolts.models.rl.sac_model import SAC from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient __all__ = [ From bfbae6bb0177ee1c30035d552f4cfbc0f349504c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 8 Sep 2021 10:14:46 +0100 Subject: [PATCH 18/20] Fix tests --- docs/source/reinforce_learn.rst | 2 +- pl_bolts/models/rl/sac_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/reinforce_learn.rst b/docs/source/reinforce_learn.rst index 4f720d7883..7f75db8995 100644 --- a/docs/source/reinforce_learn.rst +++ b/docs/source/reinforce_learn.rst @@ -849,4 +849,4 @@ Example:: trainer.fit(sac) .. autoclass:: pl_bolts.models.rl.SAC -:noindex: + :noindex: diff --git a/pl_bolts/models/rl/sac_model.py b/pl_bolts/models/rl/sac_model.py index 441f8ca9d7..6ee3f07784 100644 --- a/pl_bolts/models/rl/sac_model.py +++ b/pl_bolts/models/rl/sac_model.py @@ -402,7 +402,7 @@ def cli_main(): checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True) seed_everything(123) - trainer = Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback) + trainer = Trainer.from_argparse_args(args, deterministic=True, callbacks=checkpoint_callback) trainer.fit(model) From c0d16fda0434babe3baa3b065d8a73d8a33bc220 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 8 Sep 2021 10:25:34 +0100 Subject: [PATCH 19/20] Fix reference --- docs/source/reinforce_learn.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reinforce_learn.rst b/docs/source/reinforce_learn.rst index 7f75db8995..544c226cf6 100644 --- a/docs/source/reinforce_learn.rst +++ b/docs/source/reinforce_learn.rst @@ -781,7 +781,7 @@ Actor Critic Key Points: Soft Actor Critic (SAC) ^^^^^^^^^^^^^^^^^^^^^^^ -Soft Actor Critic model introduced in `Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor`_ +Soft Actor Critic model introduced in `Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor `__ Paper authors: Tuomas Haarnoja, Aurick Zhou, Pieter Abbeel, Sergey Levine Original implementation by: `Jason Wang `_ From 7a0e9440c9a51efe2ffc117ac1fe8f9142cf8eab Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 8 Sep 2021 10:46:25 +0100 Subject: [PATCH 20/20] Fix duplication --- docs/source/reinforce_learn.rst | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/docs/source/reinforce_learn.rst b/docs/source/reinforce_learn.rst index 544c226cf6..7431c6ad7e 100644 --- a/docs/source/reinforce_learn.rst +++ b/docs/source/reinforce_learn.rst @@ -767,16 +767,6 @@ Example:: -------------- 
-Actor-Critic Models
--------------------
-The following models are based on Actor Critic. Actor Critic combines the approaches of value-based learning (the DQN family)
-and the policy-based learning (the PG family) by learning the value function as well as the policy distribution. This approach
-updates the policy network according to the policy gradient, and updates the value network to fit the discounted rewards.
-
-Actor Critic Key Points:
- - Actor outputs a distribution of actions for controlling the agent
- - Critic outputs a value of current state for policy update suggestion
- - The addition of critic allows the model to do n-step training instead of generating an entire trajectory
 Soft Actor Critic (SAC)
 ^^^^^^^^^^^^^^^^^^^^^^^
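The hunk above de-duplicates the actor-critic overview; for reference, the split it describes (an actor head that outputs an action distribution, a critic head that outputs a scalar state value) can be sketched as below. This is an illustrative, hypothetical module, not the pl_bolts implementation::

    import torch
    from torch import nn

    class TinyActorCritic(nn.Module):
        def __init__(self, obs_dim: int, n_actions: int, hidden: int = 128):
            super().__init__()
            self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
            self.actor = nn.Linear(hidden, n_actions)  # policy logits over actions
            self.critic = nn.Linear(hidden, 1)         # scalar state-value estimate

        def forward(self, obs: torch.Tensor):
            h = self.body(obs)
            dist = torch.distributions.Categorical(logits=self.actor(h))
            value = self.critic(h).squeeze(-1)
            return dist, value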
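PATCH 18/20 above also moves the ModelCheckpoint from the Trainer's checkpoint_callback argument to callbacks. A minimal sketch of the updated wiring, assuming a pytorch_lightning version where callbacks accepts a list of Callback instances (the period=1 argument used in the patch is omitted here)::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Checkpoint on the logged "avg_reward" metric, keeping only the best model.
    checkpoint = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", verbose=True)
    trainer = Trainer(deterministic=True, callbacks=[checkpoint])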
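The __init__.py fix in PATCH 16/20 above repairs a missing comma in __all__: Python implicitly concatenates adjacent string literals, so the old list silently exported a single bogus name instead of two. A small illustration of the pitfall::

    broken = ["Reinforce", "SAC" "VanillaPolicyGradient"]  # missing comma
    fixed = ["Reinforce", "SAC", "VanillaPolicyGradient"]

    assert broken == ["Reinforce", "SACVanillaPolicyGradient"]  # one merged entry
    assert fixed == ["Reinforce", "SAC", "VanillaPolicyGradient"]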