In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
import numpy as np
import collections, random
import argparse

def get_parser():
    """Build the argument parser holding the Soft Actor-Critic hyper-parameters.

    Every option has an explicit ``type`` so values supplied on the command line are
    converted to numbers. Without it, argparse would hand back strings.

    Returns:
        argparse.ArgumentParser: parser whose namespace attributes are
        ``lr_pi, lr_q, lr_alpha, alpha_init, gamma, batch_size, buffer_size, tau,
        target_entropy``.
    """
    parser = argparse.ArgumentParser(description='Soft Actor Critic')
    parser.add_argument('--lr-pi', type=float, default=5e-4,
                        help='learning rate of the policy (actor) optimizer')
    parser.add_argument('--lr-q', type=float, default=1e-3,
                        help='learning rate of the twin Q-network optimizers')
    parser.add_argument('--lr-alpha', type=float, default=1e-3,
                        help='learning rate of the entropy temperature (log alpha)')
    # '--alpha_init' is kept for compatibility; '--alpha-init' matches the other flags.
    # dest stays 'alpha_init' because it is derived from the first long option.
    parser.add_argument('--alpha_init', '--alpha-init', type=float, default=1e-2,
                        help='initial entropy temperature alpha')
    parser.add_argument('--gamma', type=float, default=0.98,
                        help='discount factor')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='minibatch size per gradient update')
    # argparse applies `type` only to string defaults, so the default must already be
    # an int (the old 5e4 default stayed a float).
    parser.add_argument('--buffer-size', type=int, default=50000,
                        help='replay buffer capacity (transitions)')
    parser.add_argument('--tau', type=float, default=1e-2,
                        help='soft target-update rate')
    parser.add_argument('--target-entropy', type=float, default=-1.0,
                        help='target entropy for automatic temperature tuning')
    return parser

In [2]:
class ReplayBuffer():
    """Fixed-capacity FIFO store of ``(s, a, r, s2, done)`` transitions.

    Once ``buffer_size`` transitions are held, each new one evicts the oldest
    (``collections.deque`` with ``maxlen``).
    """

    def __init__(self, buffer_size, device):
        """
        Args:
            buffer_size: maximum number of stored transitions. It is cast to ``int``
                so float values such as ``5e4`` are accepted.
            device: torch device on which sampled tensors are created.
        """
        self.buffer = collections.deque(maxlen=int(buffer_size))
        self.device = device

    def append(self, transition):
        """Store one ``(s, a, r, s2, done)`` tuple."""
        self.buffer.append(transition)

    def _to_tensor(self, rows):
        """Convert nested Python rows to one float32 tensor on ``self.device``."""
        # Stack into a single ndarray first. Building a tensor from a Python list of
        # ndarrays is very slow, and PyTorch warns about it.
        return torch.as_tensor(np.asarray(rows, dtype=np.float32), device=self.device)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(s, a, r, s2, not_done)`` of float32 tensors on ``self.device``.
            ``s``/``s2`` stack the stored states along a new leading batch axis.
            ``a``, ``r`` and ``not_done`` have shape ``(batch_size, 1)``.
            ``not_done`` is 0.0 for terminal transitions and 1.0 otherwise, so it can
            multiply the bootstrap term directly.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions
                (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)

        s_batch = self._to_tensor([t[0] for t in batch])
        a_batch = self._to_tensor([[t[1]] for t in batch])
        r_batch = self._to_tensor([[t[2]] for t in batch])
        s2_batch = self._to_tensor([t[3] for t in batch])
        done_batch = self._to_tensor([[0.0] if t[4] else [1.0] for t in batch])

        return s_batch, a_batch, r_batch, s2_batch, done_batch

    @property
    def size(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [3]:
class Qnet(nn.Module):
    """Single Q(s, a) estimator for a one-dimensional action.

    The state and the action are embedded separately (64 units each). The two
    embeddings are concatenated, and a 32-unit hidden layer maps them to one
    Q-value per row.
    """

    def __init__(self, n_states):
        super(Qnet, self).__init__()
        self.fc_state = nn.Linear(n_states, 64)
        self.fc_action = nn.Linear(1, 64)  # action input width is fixed to 1
        self.fc_concat = nn.Linear(64 * 2, 32)
        self.fc_out = nn.Linear(32, 1)

    def forward(self, state, action):
        """Return Q-values with a trailing dimension of 1.

        The two embeddings are concatenated along ``dim=1``, so ``state`` and ``action``
        are expected to be 2-D ``(batch, features)`` tensors. ``action`` has one feature.
        """
        h1 = F.relu(self.fc_state(state))
        h2 = F.relu(self.fc_action(action))
        h_concat = F.relu(self.fc_concat(torch.cat([h1, h2], dim=1)))
        q = self.fc_out(h_concat)
        return q


class TwinQnet(nn.Module):
    """Two independently parameterised Q-networks for clipped double-Q learning."""

    def __init__(self, n_states):
        super(TwinQnet, self).__init__()
        # Two separate instances are required. `[Qnet(n_states)] * 2` would place the
        # SAME module in both slots, so q1 == q2 always and min(q1, q2) would give no
        # protection against overestimation.
        self.Qnet1 = Qnet(n_states)
        self.Qnet2 = Qnet(n_states)

    def forward(self, state, action):
        """Return ``(q1, q2)``, one estimate from each network."""
        q1 = self.Qnet1(state, action)
        q2 = self.Qnet2(state, action)
        return q1, q2


class Critic(nn.Module):
    """Twin-Q critic with a target copy that tracks it by Polyak averaging.

    Every keyword argument is stored as an attribute. ``n_states``, ``lr_q`` and
    ``tau`` are read by this class.
    """

    def __init__(self, **kwargs):
        super(Critic, self).__init__()
        for k, v in kwargs.items():
            setattr(self, k, v)

        self.current_twin = TwinQnet(self.n_states)
        self.target_twin = TwinQnet(self.n_states)
        self.hard_update()
        # The target network is changed only by hard_update/soft_update copies and is
        # never optimized, so it does not need gradient tracking.
        self.target_twin.requires_grad_(False)
        self.q1_optim = optim.Adam(self.current_twin.Qnet1.parameters(), lr=self.lr_q)
        self.q2_optim = optim.Adam(self.current_twin.Qnet2.parameters(), lr=self.lr_q)

    def forward(self, state, action, target=False):
        """Return ``(q1, q2)`` from the target twin if ``target`` else the current twin."""
        if target:
            return self.target_twin(state, action)
        return self.current_twin(state, action)

    def train_net(self, target, batch):
        """Regress both current Q-networks toward ``target`` with a Smooth-L1 loss.

        Args:
            target: TD targets, compared elementwise with each network's output.
            batch: ``(s, a, r, s2, not_done)`` tuple; only ``s`` and ``a`` are used.
        """
        s, a, _, _, _ = batch
        q1, q2 = self.forward(s, a)
        # smooth_l1_loss already reduces to a scalar mean. Summing the two losses and
        # back-propagating once yields the same gradients as two separate backward passes.
        loss = F.smooth_l1_loss(q1, target) + F.smooth_l1_loss(q2, target)

        self.q1_optim.zero_grad()
        self.q2_optim.zero_grad()
        loss.backward()
        self.q1_optim.step()
        self.q2_optim.step()

    @torch.no_grad()
    def soft_update(self):
        """Polyak update: ``target <- (1 - tau) * target + tau * current``."""
        for t, s in zip(self.target_twin.parameters(), self.current_twin.parameters()):
            t.mul_(1.0 - self.tau).add_(s, alpha=self.tau)

    def hard_update(self):
        """Copy the current twin's weights into the target twin verbatim."""
        self.target_twin.load_state_dict(self.current_twin.state_dict())



class Actor(nn.Module):
    """Tanh-squashed Gaussian policy for a one-dimensional action, plus the entropy
    temperature ``alpha`` (learned in log space).

    Every keyword argument is stored as an attribute. ``n_states``, ``lr_pi``,
    ``lr_alpha``, ``alpha_init`` and ``target_entropy`` are read by this class.
    """

    def __init__(self, **kwargs):
        super(Actor, self).__init__()
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.fc_state = nn.Linear(self.n_states, 128)
        self.fc_mu = nn.Linear(128, 1)
        self.fc_std = nn.Linear(128, 1)
        # Created BEFORE log_alpha is registered, so the policy optimizer only owns the
        # network weights. Adam materialises the parameter list at construction.
        self.actor_optim = optim.Adam(self.parameters(), lr=self.lr_pi)

        # Registered as a float32 Parameter so that `.to(device)` moves it alongside the
        # network and `state_dict()` saves it. A plain float64 tensor stayed on the CPU.
        self.log_alpha = nn.Parameter(
            torch.tensor(float(np.log(self.alpha_init)), dtype=torch.float32))
        self.log_alpha_optim = optim.Adam([self.log_alpha], lr=self.lr_alpha)

    def forward(self, x):
        """Sample a squashed action and its log-probability.

        Samples with ``rsample`` (reparameterisation), so gradients flow back into the
        network.

        Returns:
            ``(action, log_prob)``. ``action = tanh(u)`` lies in [-1, 1], and ``log_prob``
            includes the tanh change-of-variables correction ``-log(1 - tanh(u)^2)``.
        """
        x = F.relu(self.fc_state(x))
        mu, std = self.fc_mu(x), F.softplus(self.fc_std(x))
        dist = Normal(mu, std)
        action = dist.rsample()
        log_prob = dist.log_prob(action)

        bounded_action = torch.tanh(action)
        # The 1e-7 keeps the log finite when tanh saturates to +/-1.
        bounded_log_prob = log_prob - torch.log(1 - bounded_action.pow(2) + 1e-7)
        return bounded_action, bounded_log_prob

    def train_net(self, critic, batch):
        """Run one policy step and one temperature step.

        Args:
            critic: callable ``critic(s, a) -> (q1, q2)``.
            batch: ``(s, a, r, s2, not_done)`` tuple; only ``s`` is used.
        """
        s, _, _, _, _ = batch
        a, log_prob = self.forward(s)

        q1_val, q2_val = critic(s, a)
        min_q = torch.min(q1_val, q2_val)

        # alpha is a constant for the policy step; it gets its own update below.
        # Gradients also reach the critic's parameters here; the critic zeroes them
        # before its own update.
        alpha = self.log_alpha.exp().detach()
        actor_loss = (alpha * log_prob - min_q).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

        self.log_alpha_optim.zero_grad()
        alpha_loss = -(self.log_alpha.exp() * (log_prob + self.target_entropy).detach()).mean()
        alpha_loss.backward()
        self.log_alpha_optim.step()


class SAC():
    """Soft Actor-Critic agent: owns the policy, the twin critic and the replay buffer.

    Every keyword argument becomes an attribute of the agent. The relevant subsets are
    forwarded to the components, and a missing key raises ``KeyError``.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

        # Build all three configs before constructing anything, so a missing key fails early.
        actor_keys = ('n_states', 'lr_pi', 'lr_alpha', 'alpha_init', 'target_entropy')
        critic_keys = ('n_states', 'lr_q', 'tau')
        buffer_keys = ('buffer_size', 'device')
        actor_kwargs = dict((key, kwargs[key]) for key in actor_keys)
        critic_kwargs = dict((key, kwargs[key]) for key in critic_keys)
        buffer_kwargs = dict((key, kwargs[key]) for key in buffer_keys)

        self.actor = Actor(**actor_kwargs).to(self.device)
        self.critic = Critic(**critic_kwargs).to(self.device)
        self.buffer = ReplayBuffer(**buffer_kwargs)

    @torch.no_grad()
    def calc_q_target(self, batch):
        """Soft TD target ``r + gamma * not_done * (min(Q1', Q2') - alpha * log pi(a'|s'))``.

        The next action is sampled from the current policy, and the Q-values come from
        the target twin.
        """
        _, _, reward, next_state, not_done = batch

        next_action, next_log_prob = self.actor(next_state)
        q1_next, q2_next = self.critic(next_state, next_action, target=True)
        soft_value = torch.min(q1_next, q2_next) - self.actor.log_alpha.exp() * next_log_prob

        return reward + self.gamma * not_done * soft_value

    def learn(self):
        """One SAC iteration: sample, critic step, actor/alpha step, then soft target update."""
        sampled = self.buffer.sample(self.batch_size)
        q_target = self.calc_q_target(sampled)

        self.critic.train_net(q_target, sampled)
        self.actor.train_net(self.critic, sampled)
        self.critic.soft_update()

In [4]:
def main():
    """Train a SAC agent on ``Pendulum-v0`` and print a running average score.

    Uses the old gym API: ``reset()`` returns the observation and ``step()`` returns a
    4-tuple. Every ``interval`` episodes it prints the mean episode return over the last
    ``interval`` episodes and the current temperature alpha.
    """
    parser = get_parser()
    args = parser.parse_args([])  # defaults only; notebook kernels pass their own argv
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    env = gym.make('Pendulum-v0')
    args.n_states = env.observation_space.shape[0]

    agent = SAC(**vars(args))

    score = 0.0
    interval = 20
    warmup_transitions = 1000   # start learning once the buffer holds more than this
    updates_per_episode = 20
    n_episodes = 10000
    action_limit = env.action_space.high[0]

    try:
        for ep in range(n_episodes):
            s = env.reset()
            done = False

            while not done:
                # Acting needs no gradients; avoid building an autograd graph each step.
                with torch.no_grad():
                    a_t, _ = agent.actor(torch.tensor(s, dtype=torch.float, device=agent.device))
                action = a_t.item()
                s2, r, done, _ = env.step([action_limit * action])
                # Stored reward is scaled by 1/10 (the printed score is unscaled).
                agent.buffer.append((s, action, r / 10.0, s2, done))
                score += r
                s = s2

            if agent.buffer.size > warmup_transitions:
                for _ in range(updates_per_episode):
                    agent.learn()

            if ep % interval == 0:
                if ep != 0:
                    avg_score = score / interval
                    alpha = agent.actor.log_alpha.exp().item()
                    print(f'# of ep : {ep:04d} | Avg. score : {avg_score:.1f} | alpha : {alpha:.4f}')
                # Reset at ep 0 too, so every window holds exactly `interval` episodes
                # (previously the first window summed 21 episodes but divided by 20).
                score = 0.0
    finally:
        env.close()

if __name__ == '__main__':
    main()

# of ep : 0020 | Avg. score : -1535.2 | alpha : 0.0078
# of ep : 0040 | Avg. score : -1671.2 | alpha : 0.0059
# of ep : 0060 | Avg. score : -1496.6 | alpha : 0.0063
# of ep : 0080 | Avg. score : -1317.6 | alpha : 0.0076
# of ep : 0100 | Avg. score : -1332.9 | alpha : 0.0103
# of ep : 0120 | Avg. score : -1304.5 | alpha : 0.0141
# of ep : 0140 | Avg. score : -1227.1 | alpha : 0.0189
# of ep : 0160 | Avg. score : -1205.4 | alpha : 0.0182
# of ep : 0180 | Avg. score : -1073.3 | alpha : 0.0162
# of ep : 0200 | Avg. score : -1088.6 | alpha : 0.0161
# of ep : 0220 | Avg. score : -1083.8 | alpha : 0.0139
# of ep : 0240 | Avg. score : -907.7 | alpha : 0.0158
# of ep : 0260 | Avg. score : -928.6 | alpha : 0.0151
# of ep : 0280 | Avg. score : -945.4 | alpha : 0.0187
# of ep : 0300 | Avg. score : -799.5 | alpha : 0.0283
# of ep : 0320 | Avg. score : -752.0 | alpha : 0.0289
# of ep : 0340 | Avg. score : -822.8 | alpha : 0.0295
# of ep : 0360 | Avg. score : -965.3 | alpha : 0.0276
# of ep : 0380 | 