# Soft Actor-Critic (SAC)
Soft Actor-Critic (SAC) is an off policy algorithm that combines an actor-critic architecture with a maximum entropy framework. It is more sample-efficient and stable than prior algorithms (DDPG, TD3), making it suitable for complex, continuous action spaces.


# Introduction - Entropy
Entropy is a measure of randomness or uncertainty in a probability distribution. It is given by:

$$
H(p) = -\int p(x) \log p(x) dx = \mathbb{E}_{x \sim p} [-\log {p(x)}]
$$

In the context of reinforcement learning, entropy is used to measure the uncertainty or randomness in the policy. By maximizing entropy, the algorithm encourages exploration and prevents premature convergence to suboptimal policies.

We augment the reward with an entropy term to encourage exploration:

$$
J(\pi) = \sum_{t=0}^{T} \mathbb{E}_{s_t \sim p(s_t), a_t \sim \pi(\cdot|s_t)} [r(s_t, a_t) + \alpha H(\pi(\cdot|s_t))]
$$

Here, $\alpha$ is a constant that controls the trade-off between the reward and the entropy.



In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize Policy weights
def weights_init_(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        torch.nn.init.constant_(m.bias, 0)


# Q Network

**Network Architecture is as mentioned in the paper

SAC uses the clipped double Q-network trick to better estimate the Q function.

Additionally, SAC uses a target network, an exponential moving average of the critic network, to provide a more stable estimate of the expected return when updating the Q function. It is updated slowly as a weighted average of the target and critic network parameters.

In [2]:
class QNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_dim):
        super(QNetwork, self).__init__()
        
        self.q1 = nn.Sequential(
            nn.Linear(num_inputs + num_actions, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        self.q2 = nn.Sequential(
            nn.Linear(num_inputs + num_actions, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        self.apply(weights_init_)

    def forward(self, state, action):
        x1 = self.q1(torch.cat([state, action], 1))
        x2 = self.q2(torch.cat([state, action], 1))
        return x1, x2
            
        

# Policy Network

**Network Architecture is as mentioned in the paper

The policy network learns the mean and standard deviation of the Gaussian distribution over the actions. The mean is the expected action, and the standard deviation controls the uncertainty or randomness in the action distribution.

When sampling actions from the policy network, we use the reparameterization trick to allow for gradient descent on the policy network.

We use the tanh function to bound the action space to a finite interval (as stated in the original paper).

Thus, as shown in the paper, the log probability of the action is given by:

$$
\log \pi(\mathbf{a}|\mathbf{s}) = \log \mu(\mathbf{u}|\mathbf{s}_t) - \sum_{i=1}^D \log (1 - \tanh^2({u}_i))
$$

Additionally, we scale the action space to ensure that the output of the policy network is within the range of the action space.

When returning the scaled action, we have:
$$ P(\cdot | \pi_\theta(s)) = ((\mathcal{N}(tanh(\mu_{\theta}(s)), \sigma_{\theta}(s)) ) / 2.0) \times (h - l) + l  $$



In [3]:
LOG_SIG_MIN = -20
LOG_SIG_MAX = 2

class PolicyNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
        super(PolicyNetwork, self).__init__()
        
        self.mean = nn.Sequential(
            nn.Linear(num_inputs, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions)
        )

        self.log_std = nn.Sequential(
            nn.Linear(num_inputs, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions)
        )

        self.apply(weights_init_)

        self.action_scale = torch.tensor((action_space.high - action_space.low) / 2.)
        self.action_bias = torch.tensor((action_space.high + action_space.low) / 2.)

    def forward(self, state):
        mean = self.mean(state)
        log_std = self.log_std(state)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        action = normal.rsample() #rsample() samples from the normal distribution using the reparameterization trick
        compressed_action = torch.tanh(action)
        scaled_action = compressed_action * self.action_scale + self.action_bias
        log_prob = normal.log_prob(action)
        #We add a small constant to the log probability to prevent numerical instability
        log_prob = log_prob - torch.log(self.action_scale * (1 - compressed_action.pow(2)) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return scaled_action, log_prob, mean
    def to(self, device):
        self.action_scale = self.action_scale.to(device)
        self.action_bias = self.action_bias.to(device)
        return super().to(device)
        

# Soft Actor-Critic
**Hyperparameters are set as mentioned in the paper

gamma = 0.99

tau = 0.005

(alpha is initially set arbitrarily to 0.5 since it is learned)

network architecture: 2 hidden layers with 256 units each

optimizer: Adam

learning rate: $3 \cdot 10^{-4}$

During each gradient step, we update the parameters by calculating the loss as follows (updating gradients using torch backpropagation):

$$
\mathcal{L}(Q) = \mathbb{E}_{s_t, a_t, r_{t+1}, s_{t+1}, d \sim \mathcal{D}} [(Q(s_t, a_t) - (r + \gamma (1-d) (\text{min}Q_{target}(s_{t+1}, \pi(s_{t+1}))-\alpha \log \pi(a_{t+1}|s_{t+1}))))^2]
$$

$$
\mathcal{L}(\pi) = \mathbb{E}_{s_t \sim \mathcal{D}, \epsilon_t \sim \mathcal{N}} [\alpha \log \pi(a_t|s_t) - \text{min}Q(s_t, a_t)]
$$

$$
\mathcal{L}(\alpha) = \mathbb{E}_{s_t \sim \mathcal{D}} [-\alpha \log \pi(a_t|s_t) - \alpha \mathcal{H}_0]
$$

For the target networks, we use the following update rule:

$$
\theta_{Q'} = \tau \theta_{Q} + (1-\tau) \theta_{Q'}
$$



In [8]:
class SAC:
    def __init__(self, num_inputs, action_space, hidden_dim):
        self.critic = QNetwork(num_inputs, action_space.shape[0], hidden_dim)
        self.critic_target = QNetwork(num_inputs, action_space.shape[0], hidden_dim)
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=3e-4)
        self.policy = PolicyNetwork(num_inputs, action_space.shape[0], hidden_dim, action_space)
        self.policy_optim = optim.Adam(self.policy.parameters(), lr=3e-4)
        self.gamma = 0.99
        self.tau = 0.005
        self.alpha = 0.5
        # self.target_entropy = -torch.prod(torch.Tensor(action_space.shape[0]).to(device)).item()
        self.target_entropy = -action_space.shape[0]
        self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=3e-4)
    

    #When in inference mode, we just return the mean since we do not want to incorporate randomness
    def action(self, state, inference=False):
        if isinstance(state, (int, float)):
            state = torch.FloatTensor([state]).unsqueeze(0).to(device)
        else:
            state = torch.FloatTensor([state]).to(device)
        if inference:
            _, _, action = self.policy.sample(state)
        else:
            action, _, _ = self.policy.sample(state)
        return action.detach().numpy()[0]
    
    def update(self, memory, batch_size):
        states, actions, next_states, rewards, dones = memory.sample(batch_size)
        states = torch.FloatTensor(states).to(device)
        actions = torch.FloatTensor(actions).to(device)
        next_states = torch.FloatTensor(next_states).to(device)
        rewards = torch.FloatTensor(rewards).to(device).unsqueeze(1)
        dones = torch.FloatTensor(dones).to(device).unsqueeze(1)
        
        #update critic
        #We use torch.no_grad() here since the target network does not require gradients
        with torch.no_grad():
            next_action, next_log_pi, _ = self.policy.sample(next_states)
            q1_target, q2_target = self.critic_target(next_states, next_action)
            q_target = torch.min(q1_target, q2_target) - self.alpha * next_log_pi
            next_target_Q = rewards + ((1 - dones) * self.gamma * (q_target))
        q1, q2 = self.critic(states, actions)
        q1_loss = F.mse_loss(q1, next_target_Q)
        q2_loss = F.mse_loss(q2, next_target_Q)
        q_loss = q1_loss + q2_loss

        self.critic_optim.zero_grad()
        q_loss.backward()
        self.critic_optim.step()
        
        #update policy
        pi, log_pi, _ = self.policy.sample(states)
        q1_pi, q2_pi = self.critic(states, pi)
        min_q_pi = torch.min(q1_pi, q2_pi)
        
        policy_loss = (self.alpha * log_pi - min_q_pi).mean()
        
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()
        
        #update alpha
        alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
        
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        
        self.alpha = self.log_alpha.exp()

        #update target network
        for target_param, param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau)

# Replay Buffer
The replay buffer is used to store past experiences. We use a circular buffer to store the experiences, ejecting old data when the buffer is full.

In [9]:
import random
import numpy as np

class ReplayBuffer:
    def __init__(self, buffer_size, seed):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        self.buffer_size = buffer_size
        self.buffer = []
        self.position = 0

    def add(self, state, action, next_state, reward, done):
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, next_state, reward, done)
        self.position = (self.position + 1) % self.buffer_size

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, next_state, reward, done = map(np.stack, zip(*batch))
        return state, action, next_state, reward, done

    def __len__(self):
        return len(self.buffer)


# Training
**Hyperparameters are set as mentioned in the paper.

buffer_size = 1000000
batch_size = 256

We use the Pendulum-v1 environment from OpenAI Gym.

When training, sample an action from the policy network, and add the experience to the replay buffer. If the replay buffer has enough experiences, sample a batch of experiences from the replay buffer, and update the critic network, policy network, and entropy coefficient. Then, add the new data to the replay buffer.


In [10]:
import gym
import numpy as np
import torch

env = gym.make('Pendulum-v1')
torch.manual_seed(1)
np.random.seed(1)
env.reset(seed=1)

agent = SAC(env.observation_space.shape[0], env.action_space, 64)
memory = ReplayBuffer(1000000, 1)
batch_size = 256

episode_rewards = []

for i in range(200):
    episode_reward = 0
    episode_steps = 0
    done = False
    state, _ = env.reset()
    while not done:
        action = agent.action(state, inference=False)
        if len(memory) > batch_size:
            agent.update(memory, batch_size)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        episode_steps += 1

        memory.add(state, action, next_state, reward, done)
        state = next_state
    
    episode_rewards.append(episode_reward)
    print(f"Episode {i}: | Episode Steps {episode_steps} | Episode Reward {episode_reward}")



  if not isinstance(terminated, (bool, np.bool8)):


Episode 0: | Episode Steps 200 | Episode Reward -1288.2334574071106
Episode 1: | Episode Steps 200 | Episode Reward -1007.8151008267273
Episode 2: | Episode Steps 200 | Episode Reward -1367.0768384295882
Episode 3: | Episode Steps 200 | Episode Reward -1062.141091587513
Episode 4: | Episode Steps 200 | Episode Reward -1556.5890641205967
Episode 5: | Episode Steps 200 | Episode Reward -1400.8870330406196
Episode 6: | Episode Steps 200 | Episode Reward -1575.1278704114347
Episode 7: | Episode Steps 200 | Episode Reward -1757.998088904695
Episode 8: | Episode Steps 200 | Episode Reward -1687.8123303683117
Episode 9: | Episode Steps 200 | Episode Reward -1480.7326791162154
Episode 10: | Episode Steps 200 | Episode Reward -1081.431146458793
Episode 11: | Episode Steps 200 | Episode Reward -1706.3001678310748
Episode 12: | Episode Steps 200 | Episode Reward -906.3868484994194
Episode 13: | Episode Steps 200 | Episode Reward -1134.7879343273214
Episode 14: | Episode Steps 200 | Episode Reward

Plot the rewards over time

In [None]:
import matplotlib.pyplot as plt

episodes = range(1, len(episode_rewards) + 1)
rewards = episode_rewards

plt.figure(figsize=(12, 6))
plt.plot(episodes, rewards)
plt.title('Episode Rewards over Time')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.grid(True)
plt.show()

print(f"Maximum reward achieved: {np.max(rewards):.2f}")
print(f"Minimum reward achieved: {np.min(rewards):.2f}")


Visualize the trained agent in the environment, turn inference mode on


In [None]:
import time

def visualize_policy(agent, env, max_steps=200):
    state, _ = env.reset()
    total_reward = 0
    
    for _ in range(max_steps):
        action = agent.action(state, inference=True)
        next_state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        
        env.render()  # This will display the environment in a separate window
        time.sleep(0.01)  # Add a small delay to make the visualization visible
        
        if terminated or truncated:
            break
        
        state = next_state
    
    env.close()
    return total_reward

# Create the environment with 'human' render mode
env = gym.make('Pendulum-v1', render_mode='human')

# Run the visualization
final_reward = visualize_policy(agent, env)
print(f"Final episode reward: {final_reward}")

# Sources

Please check out the originl papers and our sources for more information!

https://arxiv.org/abs/1801.01290

https://arxiv.org/abs/1812.05905

https://lilianweng.github.io/posts/2018-04-08-policy-gradient/

https://spinningup.openai.com/en/latest/algorithms/sac.html

https://github.com/openai/spinningup/tree/master/spinup/algos/pytorch/sac

https://github.com/pranz24/pytorch-soft-actor-critic