In [1]:
%matplotlib inline

import torch
from torch.optim import Adam
import torch.nn.functional as F
import gym
import numpy as np
import random
from collections import namedtuple
from model import ValueNetwork, QNetwork, GaussianPolicy, DeterministicPolicy

## Replay Memory
Replay memory stores the transitions that the agent observes, allowing us to reuse this data later. By sampling from it randomly, the transitions that build up a batch are decorrelated. This was shown to greatly stabilize and improve the DQN training procedure, and the off-policy SAC agents below rely on it in the same way.

For this, we're going to need two classes:

- Transition - a named tuple representing a single transition in our environment. It stores a (state, action) pair together with the resulting reward and next_state, plus a mask that is 0 when the episode truly terminated at next_state (so no value is bootstrapped from it) and 1 otherwise.
- ReplayMemory - a cyclic buffer of bounded size that holds the transitions observed recently. It also implements a .sample() method for selecting a random batch of transitions for training.

In [2]:
# One environment step as stored in the replay buffer. `mask` is consumed by
# the agents' update step as a multiplier on the bootstrapped next-state value
# (see SAC2.update_parameters).
Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'mask'))


class ReplayMemory(object):
    """Bounded cyclic buffer of `Transition` tuples.

    Once `capacity` transitions have been stored, each new transition
    overwrites the oldest one.

    Attributes:
        capacity: maximum number of transitions kept.
        memory:   list holding the stored `Transition` objects.
        position: index in `memory` that the next `push` writes to.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition.

        Args:
            *args: the five `Transition` fields, in order
                (state, action, reward, next_state, mask).

        Raises:
            TypeError: if the number of fields does not match `Transition`.
                The buffer is left untouched in that case.
        """
        # Build the tuple before touching the buffer: if this raises, no
        # `None` placeholder is left behind to break a later `sample()`.
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Draws `batch_size` distinct stored transitions uniformly at random.

        Returns:
            A single `Transition` whose fields are tuples of length
            `batch_size` (field-wise transpose of the sampled transitions).

        Raises:
            ValueError: from `random.sample` if `batch_size` exceeds the
                number of stored transitions.
        """
        batch = random.sample(self.memory, batch_size)
        return Transition(*zip(*batch))

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

## Utils: Soft_update and Hard_update

In [3]:
def soft_update(target, source, tau):
    """Blend `source` parameters into `target` in place (Polyak averaging).

    Each target parameter becomes `(1 - tau) * target + tau * source`.
    Parameters are paired in the order returned by `.parameters()`.
    """
    keep = 1.0 - tau
    for tgt, src in zip(target.parameters(), source.parameters()):
        blended = tgt.data * keep + src.data * tau
        tgt.data.copy_(blended)

def hard_update(target, source):
    """Overwrite every parameter of `target` with the matching one of `source`.

    Parameters are paired in the order returned by `.parameters()`; `source`
    is left unchanged.
    """
    pairs = zip(target.parameters(), source.parameters())
    for tgt, src in pairs:
        tgt.data.copy_(src.data)

## SAC

The algorithm is a V-based method.

- First we have three networks: $V(s; \theta_V)$, $Q(s, a; \theta_Q)$ and $\pi(a \vert s; \theta_\pi)$.

- We want $Q(s, a; \theta_Q) = Q_V = \sum_{s'} p(s' \vert s, a)\left(r(s, a, s') + \gamma V(s';\theta_V) \right)$;

  $$
  J(\theta_Q) = \mathbb{E}_{(s, a) \sim \mathcal{D}} \left\{\frac{1}{2} (\sum_{s'} p(s' \vert s, a)\left(r(s, a, s') + \gamma V(s';\theta_V)\right) - Q(s, a; \theta_Q) )^2 \right\};
  $$

- We want $\pi(a \vert s; \theta_\pi) = \pi^*_{Q, soft}(a \vert s) = \frac{\exp\left\{\frac{1}{\alpha} Q(s, a; \theta_Q)\right\}}{\sum_{a'} \exp\left\{\frac{1}{\alpha} Q(s, a'; \theta_Q)\right\}}$;

  $$
  \begin{align*}
  J(\theta_\pi) =& \mathbb{E}_{s \sim \mathcal{D}}\left\{D_{KL} \left(\pi(\cdot \vert s; \theta_\pi) \Vert \pi^*_{Q, soft}(s, \cdot) \right)\right\} \\
  =& \mathbb{E}_{s \sim \mathcal{D}, a \sim \pi(\cdot \vert s;\theta_\pi)} \left\{\log(\pi(a \vert s;\theta_\pi)) - \log(\pi^*_{Q, soft}(a \vert s))\right\}\\
  =& \mathbb{E}_{s \sim \mathcal{D}, a \sim \pi(\cdot \vert s;\theta_\pi)} \left\{\log(\pi(a \vert s;\theta_\pi)) - \frac{1}{\alpha} Q(s, a;\theta_Q) + \log\left(\sum_a \exp\left\{\frac{1}{\alpha} Q(s, a; \theta_Q)\right\}\right)\right\}
  \end{align*}
  $$

  If we use Gaussian distribution in continuous action space:

  $$
  J(\theta_\pi) = \mathbb{E}_{s \sim \mathcal{D}, \epsilon \sim \mathcal{N}(0, 1)} \left\{\log(\pi(f(s; \epsilon, \theta_\pi) \vert s)) - \frac{1}{\alpha} Q(s, f(s; \epsilon, \theta_\pi);\theta_Q) + \log\left(\sum_a \exp\left\{\frac{1}{\alpha} Q(s, a; \theta_Q)\right\}\right)\right\}
  $$

  $$
  \nabla_{\theta_\pi} J(\theta_\pi) = \nabla_{\theta_\pi}\log(\pi(f(s; \epsilon, \theta_\pi) \vert s)) - \frac{1}{\alpha} \nabla_{\theta_\pi} f(s; \epsilon, \theta_\pi) \nabla_a Q(s,a; \theta_{Q})\vert_{a = f(s; \epsilon, \theta_\pi)}
  $$

- We want $V(s; \theta_V) = T^\pi_{soft} V(s; \theta_V) = \mathbb{E}_{a \sim \pi(\cdot \vert s; \theta_\pi)} \left[Q(s, a; \theta_Q)  - \alpha \log(\pi(a \vert s; \theta_\pi))  \right]$;

$$
\begin{align*}
J(\theta_V) =& \mathbb{E}_{s \sim \mathcal{D}} \left\{\frac{1}{2} \left(V(s;\theta_V) - \mathbb{E}_{a \sim \pi(\cdot \vert s; \theta_\pi)} \left[Q(s, a; \theta_Q)  - \alpha \log(\pi(a \vert s; \theta_\pi))  \right]\right)^2 \right\}
\end{align*}
$$

The paper uses two value networks: $V(s; \theta_V)$ and a target network $\hat V(s; \theta_{\hat V})$, which is updated as $\theta_{\hat V} \leftarrow (1 - \tau) \theta_{\hat V} + \tau \theta_V$.

In [4]:
class SAC(object):
    """Soft Actor-Critic agent with a separate state-value network.

    Holds four networks built from the project-local `model` module: a value
    network V, a target copy of V, a twin-output critic Q, and a policy.
    `update_parameters` performs one gradient step on V, Q and the policy,
    and softly updates the V target.
    """

    def __init__(self, num_inputs, action_space, args):
        """Builds networks and optimizers.

        Args:
            num_inputs: size of the observation vector.
            action_space: environment action space; `action_space.shape[0]`
                is used as the action dimension and the object itself is
                forwarded to the policy constructor.
            args: dict with keys 'gamma', 'tau', 'alpha', 'policy_type',
                'target_update_interval', 'cuda', 'hidden_size' and 'lr'.
        """
        self.gamma = args['gamma']
        self.tau = args['tau']
        self.alpha = args['alpha']

        self.policy_type = args['policy_type']
        self.target_update_interval = args['target_update_interval']

        self.device = torch.device("cuda" if args['cuda'] else "cpu")

        num_actions = action_space.shape[0]
        hidden_size = args['hidden_size']

        self.value = ValueNetwork(num_inputs, hidden_size).to(self.device)
        self.value_optim = Adam(self.value.parameters(), lr=args['lr'])

        # The target starts as an exact copy of the online value network.
        self.value_target = ValueNetwork(num_inputs, hidden_size).to(self.device)
        hard_update(self.value_target, self.value)

        self.critic = QNetwork(num_inputs, num_actions, hidden_size).to(self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args['lr'])

        if self.policy_type == "Gaussian":
            policy_cls = GaussianPolicy
        else:
            # NOTE(review): the entropy weight is forced to 0 for the
            # non-Gaussian policy, presumably because a deterministic policy
            # has no entropy term -- confirm against model.DeterministicPolicy.
            self.alpha = 0
            policy_cls = DeterministicPolicy
        self.policy = policy_cls(num_inputs, num_actions, hidden_size, action_space).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args['lr'])

    def _to_tensor(self, values):
        """Converts a sequence of numbers/arrays to a float32 tensor on `self.device`."""
        # Stack with numpy first: building a tensor straight from a tuple of
        # ndarrays converts element by element and is very slow.
        return torch.as_tensor(np.asarray(values), dtype=torch.float32, device=self.device)

    def select_action(self, state, eval=False):
        """Returns an action for a single observation as a numpy array.

        Args:
            state: one observation (no batch dimension).
            eval: when false, use the first output of `policy.sample`;
                otherwise use its third output (presumably the
                deterministic/mean action -- confirm in model.py).
        """
        state = self._to_tensor(state).unsqueeze(0)
        with torch.no_grad():
            sampled_action, _, eval_action = self.policy.sample(state)
        action = eval_action if eval else sampled_action
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        """Runs one optimisation step on critic, policy and value networks.

        Args:
            memory: object whose `sample(batch_size)` returns the five
                batched fields (state, action, reward, next_state, mask).
            batch_size: number of transitions to sample.
            updates: running update counter; the value target is softly
                updated when it is a multiple of `target_update_interval`.

        Returns:
            Tuple `(critic_loss + value_loss, policy_loss)` of scalar tensors.
        """
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size)

        state_batch = self._to_tensor(state_batch)
        next_state_batch = self._to_tensor(next_state_batch)
        action_batch = self._to_tensor(action_batch)
        reward_batch = self._to_tensor(reward_batch).unsqueeze(1)
        mask_batch = self._to_tensor(mask_batch).unsqueeze(1)

        with torch.no_grad():
            # mask is 0 for terminal transitions, so they must not bootstrap
            # from the target value of the next state.
            q_target_batch = reward_batch + mask_batch * self.gamma * self.value_target(next_state_batch)

        q1_batch, q2_batch = self.critic(state_batch, action_batch)
        critic_loss = F.mse_loss(q1_batch, q_target_batch) + F.mse_loss(q2_batch, q_target_batch)

        a, log_p_a, _ = self.policy.sample(state_batch)
        q1_a, q2_a = self.critic(state_batch, a)
        min_q_a = torch.min(q1_a, q2_a)
        policy_loss = ((self.alpha * log_p_a) - min_q_a).mean()

        # The value target is a constant for the value loss.
        v_target = (min_q_a - self.alpha * log_p_a).detach()
        v = self.value(state_batch)
        value_loss = F.mse_loss(v, v_target)

        # Run every backward pass before any optimizer step. The policy graph
        # goes through the critic weights, and stepping the critic first
        # modifies them in place, which autograd rejects with a RuntimeError.
        # The policy pass goes first because it also deposits gradients in the
        # critic; critic_optim.zero_grad() below discards those.
        self.policy_optim.zero_grad()
        policy_loss.backward()

        self.critic_optim.zero_grad()
        critic_loss.backward()

        self.value_optim.zero_grad()
        value_loss.backward()

        self.critic_optim.step()
        self.policy_optim.step()
        self.value_optim.step()

        if updates % self.target_update_interval == 0:
            soft_update(self.value_target, self.value, self.tau)

        return critic_loss + value_loss, policy_loss

## SAC2

- $Q(s, a; \theta_Q) = \mathbb{E}_{s, a, s' \sim \mathcal{D}} [r(s, a, s') + \gamma \mathbb{E}_{a' \sim \pi(\cdot \vert s'; \theta_\pi)}(Q(s', a'; \theta_{Q-target}) - \alpha \log\pi(a' \vert s'; \theta_\pi))]$;
- $\pi(a\vert s; \theta_\pi) = \frac{\exp\left\{\frac{1}{\alpha} Q(s, a; \theta_Q)\right\}}{\sum_{a'} \exp\left\{\frac{1}{\alpha} Q(s, a' ; \theta_Q)\right\}}$.

However we use Gaussian distribution in continuous action space:

  $$
  J(\theta_\pi) = \mathbb{E}_{s \sim \mathcal{D}, \epsilon \sim \mathcal{N}(0, 1)} \left\{\log(\pi(f(s; \epsilon, \theta_\pi) \vert s)) - \frac{1}{\alpha} Q(s, f(s; \epsilon, \theta_\pi);\theta_Q) + \log\left(\sum_a \exp\left\{\frac{1}{\alpha} Q(s, a; \theta_Q)\right\}\right)\right\}
  $$

  $$
  \nabla_{\theta_\pi} J(\theta_\pi) = \nabla_{\theta_\pi}\log(\pi(f(s; \epsilon, \theta_\pi) \vert s)) - \frac{1}{\alpha} \nabla_{\theta_\pi} f(s; \epsilon, \theta_\pi) \nabla_a Q(s,a; \theta_{Q})\vert_{a = f(s; \epsilon, \theta_\pi)}
  $$

In [5]:
class SAC2(object):
    """Soft Actor-Critic agent without a value network.

    Holds a twin-output critic Q, a target copy of Q, and a policy, all built
    from the project-local `model` module. The critic target is computed from
    the target critic and the policy's entropy term.
    """

    def __init__(self, num_inputs, action_space, args):
        """Builds networks and optimizers.

        Args:
            num_inputs: size of the observation vector.
            action_space: environment action space; `action_space.shape[0]`
                is used as the action dimension and the object itself is
                forwarded to the policy constructor.
            args: dict with keys 'gamma', 'tau', 'alpha', 'policy_type',
                'target_update_interval', 'cuda', 'hidden_size' and 'lr'.
        """
        self.gamma = args['gamma']
        self.tau = args['tau']
        self.alpha = args['alpha']

        self.policy_type = args['policy_type']
        self.target_update_interval = args['target_update_interval']

        self.device = torch.device("cuda" if args['cuda'] else "cpu")

        num_actions = action_space.shape[0]
        hidden_size = args['hidden_size']

        self.critic = QNetwork(num_inputs, num_actions, hidden_size).to(self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=args['lr'])

        # The target starts as an exact copy of the online critic.
        self.critic_target = QNetwork(num_inputs, num_actions, hidden_size).to(self.device)
        hard_update(self.critic_target, self.critic)

        if self.policy_type == "Gaussian":
            policy_cls = GaussianPolicy
        else:
            # NOTE(review): the entropy weight is forced to 0 for the
            # non-Gaussian policy, presumably because a deterministic policy
            # has no entropy term -- confirm against model.DeterministicPolicy.
            self.alpha = 0
            policy_cls = DeterministicPolicy
        self.policy = policy_cls(num_inputs, num_actions, hidden_size, action_space).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=args['lr'])

    def _to_tensor(self, values):
        """Converts a sequence of numbers/arrays to a float32 tensor on `self.device`."""
        # Stack with numpy first: building a tensor straight from a tuple of
        # ndarrays converts element by element and is very slow.
        return torch.as_tensor(np.asarray(values), dtype=torch.float32, device=self.device)

    def select_action(self, state, eval=False):
        """Returns an action for a single observation as a numpy array.

        Args:
            state: one observation (no batch dimension).
            eval: when false, use the first output of `policy.sample`;
                otherwise use its third output (presumably the
                deterministic/mean action -- confirm in model.py).
        """
        state = self._to_tensor(state).unsqueeze(0)
        with torch.no_grad():
            sampled_action, _, eval_action = self.policy.sample(state)
        action = eval_action if eval else sampled_action
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        """Runs one optimisation step on the critic and the policy.

        Args:
            memory: object whose `sample(batch_size)` returns the five
                batched fields (state, action, reward, next_state, mask).
            batch_size: number of transitions to sample.
            updates: running update counter; the critic target is softly
                updated when it is a multiple of `target_update_interval`.

        Returns:
            Tuple `(critic_loss, policy_loss)` of scalar tensors.
        """
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size)

        state_batch = self._to_tensor(state_batch)
        next_state_batch = self._to_tensor(next_state_batch)
        action_batch = self._to_tensor(action_batch)
        reward_batch = self._to_tensor(reward_batch).unsqueeze(1)
        mask_batch = self._to_tensor(mask_batch).unsqueeze(1)

        with torch.no_grad():
            next_action_batch, log_p_next_action_batch, _ = self.policy.sample(next_state_batch)
            q1_next_target_batch, q2_next_target_batch = self.critic_target(next_state_batch, next_action_batch)
            min_q_next_target_batch = torch.min(q1_next_target_batch, q2_next_target_batch) - self.alpha * log_p_next_action_batch
            # mask zeroes the bootstrapped term for terminal transitions.
            next_q_batch = reward_batch + mask_batch * self.gamma * min_q_next_target_batch

        q1_batch, q2_batch = self.critic(state_batch, action_batch)
        critic_loss = F.mse_loss(q1_batch, next_q_batch) + F.mse_loss(q2_batch, next_q_batch)

        a, log_p_a, _ = self.policy.sample(state_batch)
        q1_a, q2_a = self.critic(state_batch, a)
        min_q_a = torch.min(q1_a, q2_a)
        policy_loss = ((self.alpha * log_p_a) - min_q_a).mean()

        # Run both backward passes before any optimizer step. The policy graph
        # goes through the critic weights, and stepping the critic first
        # modifies them in place, which autograd rejects with a RuntimeError.
        # The policy pass goes first because it also deposits gradients in the
        # critic; critic_optim.zero_grad() below discards those.
        self.policy_optim.zero_grad()
        policy_loss.backward()

        self.critic_optim.zero_grad()
        critic_loss.backward()

        self.critic_optim.step()
        self.policy_optim.step()

        if updates % self.target_update_interval == 0:
            soft_update(self.critic_target, self.critic, self.tau)

        return critic_loss, policy_loss

### Actor-Critic Training

In [6]:
def actor_critic_train(algorithm, args):
    """Trains an agent on a gym environment, yielding stats after each episode.

    Args:
        algorithm: agent class, called as
            `algorithm(obs_dim, env.action_space, args)`; the instance must
            provide `select_action(state)` and
            `update_parameters(memory, batch_size, updates)`.
        args: dict with keys 'env_name', 'seed', 'replay_size', 'num_steps',
            'start_steps', 'batch_size' and 'updates_per_step', plus
            whatever `algorithm` reads.

    Yields:
        `(total_numsteps, episode_reward, critic_loss, actor_loss)` once per
        finished episode. The two losses are those of the last update in the
        episode, or 0 if no update ran.

    Training stops after the first episode that brings `total_numsteps` to
    `args['num_steps']` or beyond. The environment is closed when the
    generator finishes, is closed, or is garbage-collected.
    """
    env = gym.make(args['env_name'])
    try:
        torch.manual_seed(args['seed'])
        np.random.seed(args['seed'])
        # ReplayMemory.sample draws with the stdlib `random` module, so it
        # has to be seeded too for runs to be repeatable.
        random.seed(args['seed'])
        env.seed(args['seed'])

        agent = algorithm(env.observation_space.shape[0], env.action_space, args)
        memory = ReplayMemory(args['replay_size'])

        total_numsteps = 0
        updates = 0

        # 'num_steps' is a budget of environment steps, not of episodes.
        while total_numsteps < args['num_steps']:
            episode_reward = 0
            episode_steps = 0
            done = False
            state = env.reset()
            critic_loss = 0
            actor_loss = 0
            while not done:
                # Act uniformly at random for the first 'start_steps' steps.
                if args['start_steps'] > total_numsteps:
                    action = env.action_space.sample()
                else:
                    action = agent.select_action(state)

                if len(memory) > args['batch_size']:
                    for _ in range(args['updates_per_step']):
                        critic_loss, actor_loss = agent.update_parameters(memory, args['batch_size'], updates)
                        updates += 1

                next_state, reward, done, _ = env.step(action)
                episode_steps += 1
                total_numsteps += 1
                episode_reward += reward

                # Keep mask at 1 when the episode stopped only because the
                # step limit was reached; otherwise mask is 0 on `done`.
                mask = 1 if episode_steps == env._max_episode_steps else float(not done)

                memory.push(state, action, reward, next_state, mask)
                state = next_state

            yield total_numsteps, episode_reward, critic_loss, actor_loss
    finally:
        # Runs even when the consumer stops iterating early.
        env.close()

In [7]:
# Hyper-parameters shared by the agents and the training loop.
args = {
    'env_name'              : 'HalfCheetah-v2',
    'policy_type'           : 'Gaussian',   # any other value selects DeterministicPolicy
    'gamma'                 : 0.99,         # discount factor
    'tau'                   : 0.005,        # soft target-update coefficient
    'lr'                    : 0.0003,       # Adam learning rate for every network
    'alpha'                 : 0.2,          # entropy weight
    'seed'                  : 0,
    'batch_size'            : 100,
    'num_steps'             : 1000000,      # training budget passed to actor_critic_train
    'hidden_size'           : 256,
    'updates_per_step'      : 1,
    'start_steps'           : 10000,        # steps taken with random actions before using the policy
    'target_update_interval': 1,
    'replay_size'           : 1000000,
    # Use the GPU only when one is present so the notebook also runs on CPU-only hosts.
    'cuda'                  : torch.cuda.is_available()
}

In [8]:
# Train the SAC agent and log one line per finished episode.
training_log = actor_critic_train(SAC, args)
for total_numsteps, episode_reward, critic_loss, actor_loss in training_log:
    log_line = "Total_steps {:>10d}: EpReward = {:>15.6f}, Critic_Loss = {:>10.6f}, Actor_Loss = {:>10.6f}".format(
        total_numsteps, episode_reward, critic_loss, actor_loss
    )
    print(log_line)

3.556136, Critic_Loss =   3.047202, Actor_Loss = -33.649033
Total_steps      33000: EpReward =       13.401999, Critic_Loss =   3.518537, Actor_Loss = -32.891869
Total_steps      34000: EpReward =      185.230455, Critic_Loss =   4.823693, Actor_Loss = -34.864418
Total_steps      35000: EpReward =      -72.859153, Critic_Loss =   3.153778, Actor_Loss = -35.856899
Total_steps      36000: EpReward =      767.623638, Critic_Loss =   3.291799, Actor_Loss = -34.900433
Total_steps      37000: EpReward =      858.033783, Critic_Loss =   2.781630, Actor_Loss = -35.462585
Total_steps      38000: EpReward =      -49.936063, Critic_Loss =   3.377226, Actor_Loss = -37.187977
Total_steps      39000: EpReward =      777.994755, Critic_Loss =   3.134436, Actor_Loss = -39.884617
Total_steps      40000: EpReward =       -9.701410, Critic_Loss =   3.997271, Actor_Loss = -40.041836
Total_steps      41000: EpReward =      164.914971, Critic_Loss =   3.440341, Actor_Loss = -38.187122
Total_steps      42000

KeyboardInterrupt: 