In [None]:
# https://github.com/vaishak2future/sac/blob/master/sac.ipynb

import math
import random

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal

use_cuda = torch.cuda.is_available()
device   = torch.device("cuda" if use_cuda else "cpu")

# Seed the Python, NumPy and PyTorch global RNGs so that Restart & Run All
# repeats the same run on CPU. The gym env and its action_space keep their own
# RNGs, which are not seeded here. CUDA kernels may also be nondeterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of transitions for off-policy training.

    Transitions are stored as ``(state, action, reward, next_state, terminal)``
    tuples. Once ``capacity`` entries exist, each new push overwrites the
    oldest one.
    """

    def __init__(self, capacity):
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self.capacity = capacity
        self.buffer = []
        self.posn = 0  # slot the next push will write to

    def push(self, state, action, reward, next_state, terminal):
        """Store one transition, overwriting the oldest entry once full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.posn] = (state, action, reward, next_state, terminal)
        # Wrap on the capacity (not the current position) to form a ring.
        self.posn = (self.posn + 1) % self.capacity

    def sample(self, batch_size):
        """Uniformly sample ``batch_size`` distinct stored transitions.

        Returns:
            Five numpy arrays ``(state, action, reward, next_state, terminal)``,
            each stacked along a new leading batch axis.

        Raises:
            ValueError: if ``batch_size`` exceeds ``len(self)`` (from ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, terminal = map(np.stack, zip(*batch))
        return state, action, reward, next_state, terminal

    def __len__(self):
        return len(self.buffer)
    
class NormalizedActions(gym.ActionWrapper):
    """Let an agent act in [-1, 1] on an env whose Box action space is [low, high].

    ``action`` maps a normalized action linearly onto ``[low, high]`` before it
    reaches the wrapped env. ``_reverse_action`` / ``reverse_action`` map an
    env-space action back to ``[-1, 1]``.
    """

    def action(self, action):
        """Rescale a normalized action in [-1, 1] to the env's [low, high] range."""
        low = self.action_space.low
        high = self.action_space.high

        action = low + (action + 1.0) * 0.5 * (high - low)
        return np.clip(action, low, high)

    def _reverse_action(self, action):
        """Rescale an env-space action in [low, high] back to [-1, 1]."""
        low = self.action_space.low
        high = self.action_space.high

        action = 2 * (action - low) / (high - low) - 1
        # The result lives in normalized space, so clip to [-1, 1], not [low, high].
        return np.clip(action, -1.0, 1.0)

    def reverse_action(self, action):
        """Public name used by newer gym versions; delegates to ``_reverse_action``."""
        return self._reverse_action(action)

In [None]:
class ValueNetwork(nn.Module):
    """State-value network V(s): two ReLU hidden layers and a scalar output.

    The output layer's weight and bias are re-drawn uniformly from
    [-init_w, init_w], which keeps the initial outputs small.
    """

    def __init__(self, num_states, hidden_size, init_w=1e-3):
        super().__init__()

        self.linear1 = nn.Linear(num_states, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

        for tensor in (self.linear3.weight, self.linear3.bias):
            tensor.data.uniform_(-init_w, init_w)

    def forward(self, state):
        hidden = F.relu(self.linear1(state))
        hidden = F.relu(self.linear2(hidden))
        return self.linear3(hidden)
    
class SoftQNetwork(nn.Module):
    """Action-value network Q(s, a) over the concatenated state and action.

    The architecture is two ReLU hidden layers and a scalar head. The head's
    weight and bias are re-drawn uniformly from [-init_w, init_w].
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=1e-3):
        super().__init__()

        in_features = num_inputs + num_actions
        self.linear1 = nn.Linear(in_features, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)

        for tensor in (self.linear3.weight, self.linear3.bias):
            tensor.data.uniform_(-init_w, init_w)

    def forward(self, state, action):
        # Join along dim 1, so state and action are expected as (batch, features).
        joint = torch.cat((state, action), dim=1)
        hidden = F.relu(self.linear1(joint))
        hidden = F.relu(self.linear2(hidden))
        return self.linear3(hidden)
    
class PolicyNetwork(nn.Module):
    """Tanh-squashed Gaussian policy.

    ``forward`` returns the mean and the clamped log-std of a Gaussian over
    pre-squash actions. ``evaluate`` samples reparameterized actions, with
    log-probabilities, for training. ``get_action`` samples a single action
    for acting in the env.
    """

    def __init__(self, num_inputs, num_actions, hidden_size, init_w=1e-3, log_std_min=-20, log_std_max=2):
        super(PolicyNetwork, self).__init__()

        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)

        self.mean_linear = nn.Linear(hidden_size, num_actions)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)

        self.log_std_linear = nn.Linear(hidden_size, num_actions)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        # Bound log-std so that std = exp(log_std) stays numerically sane.
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)

        return mean, log_std

    def evaluate(self, state, epsilon=1e-3):
        """Sample reparameterized, tanh-squashed actions for a batch of states.

        Args:
            state: batch of states, shaped ``(batch, num_inputs)``.
            epsilon: small constant inside ``log(1 - tanh^2 + epsilon)`` to
                avoid ``log(0)`` when the action saturates.

        Returns:
            ``(action, log_prob, z, mean, log_std)``:
              - ``action`` is the tanh-squashed action.
              - ``log_prob`` is the joint log-density of the squashed action,
                summed over action dims and shaped ``(batch, 1)``.
              - ``z`` is the standard-normal noise used.
              - ``mean`` and ``log_std`` come from ``forward``.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()

        # Draw independent noise per batch row and action dim, on mean's device.
        # A single scalar draw would give the whole batch identical noise.
        z = torch.randn_like(mean)
        pre_tanh = mean + std * z
        action = torch.tanh(pre_tanh)

        # Change of variables for tanh: subtract log|d tanh(u)/du| = log(1 - tanh(u)^2).
        log_prob = Normal(mean, std).log_prob(pre_tanh) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob, z, mean, log_std

    def get_action(self, state):
        """Sample one squashed action for a single (unbatched) state.

        Returns a 1-D CPU tensor of length ``num_actions`` that does not track
        gradients.
        """
        param_device = next(self.parameters()).device
        state = torch.FloatTensor(state).unsqueeze(0).to(param_device)
        with torch.no_grad():
            mean, log_std = self.forward(state)
            z = torch.randn_like(mean)
            action = torch.tanh(mean + log_std.exp() * z)

        return action.cpu()[0]

def plot(frame_idx, rewards):
    """Redraw the per-episode reward curve in place in the notebook output.

    Args:
        frame_idx: total environment steps taken so far (shown in the title).
        rewards: list of per-episode returns; may be empty.
    """
    # These names were used here but never imported elsewhere in the notebook.
    import matplotlib.pyplot as plt
    from IPython.display import clear_output

    clear_output(True)
    fig = plt.figure(figsize=(20, 5))
    ax = fig.add_subplot(131)
    last_reward = rewards[-1] if rewards else None  # avoid IndexError before the first episode ends
    ax.set_title('frame %s. reward: %s' % (frame_idx, last_reward))
    ax.set_xlabel('episode')
    ax.set_ylabel('episode reward')
    ax.plot(rewards)
    plt.show()
    plt.close(fig)

In [None]:
def update(batch_size, gamma=0.999, tau=1e-2,):
    """Run one SAC update (twin-Q, separate value network) on a sampled minibatch.

    Reads module-level objects created in the setup cell:
    - ``replay_buffer``, ``device``
    - networks: ``q1_net``, ``q2_net``, ``value_net``, ``target_value_net``,
      ``policy_net``
    - losses: ``q1_criterion``, ``q2_criterion``, ``value_criterion``
    - optimizers: ``q1_optim``, ``q2_optim``, ``value_optimizer``,
      ``policy_optimizer``

    Args:
        batch_size: number of transitions to sample from the replay buffer.
        gamma: discount factor for the bootstrapped Q target.
        tau: Polyak coefficient for the soft update of ``target_value_net``.
    """
    state, action, reward, next_state, terminal = replay_buffer.sample(batch_size)

    state = torch.FloatTensor(state).to(device)
    next_state = torch.FloatTensor(next_state).to(device)
    action = torch.FloatTensor(action).to(device)
    reward = torch.FloatTensor(reward).unsqueeze(1).to(device)
    terminal = torch.FloatTensor(np.float32(terminal)).unsqueeze(1).to(device)

    predicted_q1 = q1_net(state, action)
    predicted_q2 = q2_net(state, action)
    predicted_value = value_net(state)
    new_action, log_prob, epsilon, mean, log_std = policy_net.evaluate(state)
    # Reduce to one column so the log-prob broadcasts against the (batch, 1)
    # Q/V outputs. This is a no-op if it already has a single column.
    log_prob = log_prob.sum(dim=1, keepdim=True)

    # Train Q: regress both critics onto r + gamma * (1 - terminal) * V_target(s').
    target_value = target_value_net(next_state)
    target_q_value = reward + (1 - terminal) * gamma * target_value
    q1_loss = q1_criterion(predicted_q1, target_q_value.detach())
    q2_loss = q2_criterion(predicted_q2, target_q_value.detach())

    q1_optim.zero_grad()
    q1_loss.backward()
    q1_optim.step()
    q2_optim.zero_grad()
    q2_loss.backward()
    q2_optim.step()

    # Train value function toward min(Q1, Q2)(s, a~pi) - log pi(a|s).
    predicted_new_q_value = torch.min(q1_net(state, new_action), q2_net(state, new_action))
    target_value_fcn = predicted_new_q_value - log_prob
    value_loss = value_criterion(predicted_value, target_value_fcn.detach())

    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()

    # Train policy: minimize E[log pi(a|s) - Q(s, a)] through the reparameterized action.
    policy_loss = (log_prob - predicted_new_q_value).mean()

    policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_optimizer.step()

    # Polyak-average the online value network into the target network.
    for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + param.data * tau
        )

In [None]:
# NOTE(review): "Pendulum-v0" and the 4-tuple env.step API used in the training
# cell exist only in older gym releases. Newer gym/gymnasium register "Pendulum-v1".
env = NormalizedActions(gym.make("Pendulum-v0"))

num_actions = env.action_space.shape[0]
num_states  = env.observation_space.shape[0]
hidden_size = 256

value_net = ValueNetwork(num_states, hidden_size).to(device)
target_value_net = ValueNetwork(num_states, hidden_size).to(device)

q1_net = SoftQNetwork(num_states, num_actions, hidden_size).to(device)
q2_net = SoftQNetwork(num_states, num_actions, hidden_size).to(device)
policy_net = PolicyNetwork(num_states, num_actions, hidden_size).to(device)

# Start the target value network as an exact copy of the online one.
for target_param, param in zip(target_value_net.parameters(), value_net.parameters()):
    target_param.data.copy_(param.data)

value_criterion  = nn.MSELoss()
q1_criterion = nn.MSELoss()
q2_criterion = nn.MSELoss()

value_lr  = 1e-3
soft_q_lr = 1e-3
policy_lr = 1e-3

value_optimizer  = optim.Adam(value_net.parameters(), lr=value_lr)
q1_optim = optim.Adam(q1_net.parameters(), lr=soft_q_lr)
q2_optim = optim.Adam(q2_net.parameters(), lr=soft_q_lr)
policy_optimizer = optim.Adam(policy_net.parameters(), lr=policy_lr)

replay_buffer_size = 1000000
replay_buffer = ReplayBuffer(replay_buffer_size)

In [None]:
# HYPER PARAMS

max_frames = 40000
max_steps = 500
frame_idx = 0
rewards = []
batch_size = 128
start_steps = 1000   # env steps of uniform-random exploration before the policy acts
plot_every = 1000    # redraw the reward curve every this many env steps

# TRAINING

while frame_idx < max_frames:
    state = env.reset()
    episode_reward = 0

    for step in range(max_steps):
        if frame_idx > start_steps:
            # Policy output is already in the wrapper's normalized [-1, 1] space.
            action = policy_net.get_action(state).detach().numpy()
        else:
            # Sample in the same [-1, 1] space the policy uses. env.action_space
            # is the raw env range, which NormalizedActions would rescale again.
            action = np.random.uniform(-1.0, 1.0, size=env.action_space.shape).astype(np.float32)
        next_state, reward, done, info = env.step(action)

        # A time-limit cutoff is not a true terminal state, so keep
        # bootstrapping from next_state in that case.
        terminal = done and not info.get("TimeLimit.truncated", False)
        replay_buffer.push(state, action, reward, next_state, terminal)

        state = next_state
        episode_reward += reward
        frame_idx += 1

        if len(replay_buffer) > batch_size:
            update(batch_size)

        if frame_idx % plot_every == 0 and rewards:
            plot(frame_idx, rewards)

        if done:
            break

    rewards.append(episode_reward)