In [1]:
# !pip install gym[box2d]
# !pip install numpy --upgrade

In [22]:
import os
import gym
import torch
import random
import itertools

import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.distributions import Categorical

from copy import deepcopy
from collections import deque

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [23]:
def set_seed(env, seed=0):
    os.environ["PYTHONHASHSEED"] = str(seed)
    env.seed(seed)
    env.action_space.seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # torch.use_deterministic_algorithms(True)


class ReplayBuffer:
    def __init__(self, size):
        self.buffer = deque(maxlen=size)
        
    def add(self, transition):
        self.buffer.append(transition)
    
    def sample(self, size):
        batch = random.sample(self.buffer, size)
        return list(zip(*batch))

In [24]:
class Actor(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=256):
        super().__init__()
                
        self.actor = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size)
        )
        
    def forward(self, state, eval_mode=False, return_probs=False):
        logits = self.actor(state)
        probs = F.softmax(logits, dim=-1)
        
        policy_dist = Categorical(probs=probs)

        if eval_mode:
            action = torch.argmax(probs, dim=-1)
        else:
            action = policy_dist.sample()

        if return_probs:
            log_probs = F.log_softmax(logits, dim=-1)

            return action, probs, log_probs

        return action


class Critic(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=256):
        super().__init__()
        
        self.critic = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size)
        )
    
    def forward(self, state):        
        return self.critic(state)

In [25]:
class SoftActorCritic:
    def __init__(self, state_size, action_size, hidden_size, target_entropy_scale=0.5, gamma=0.99, tau=0.005, init_alpha=None, actor_lr=1e-4, critic_lr=1e-4, alpha_lr=1e-4):
        self.actor = Actor(state_size, action_size, hidden_size).to(DEVICE)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        
        self.critic1 = Critic(state_size, action_size, hidden_size).to(DEVICE)
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=critic_lr)
        self.target_critic1 = deepcopy(self.critic1)
        
        self.critic2 = Critic(state_size, action_size, hidden_size).to(DEVICE)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=critic_lr)
        self.target_critic2 = deepcopy(self.critic2)

        for p in itertools.chain(self.target_critic1.parameters(), self.target_critic2.parameters()):
            p.requires_grad = False

        self.tau = tau
        self.gamma = gamma
        
        self.init_alpha = 0.0 if init_alpha is None else np.log(init_alpha)
        self.target_entropy = -np.log((1.0 / action_size)) * target_entropy_scale # * 0.98
  
        self.log_alpha = torch.tensor([self.init_alpha], dtype=torch.float32, device=DEVICE, requires_grad=True)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)
        self.alpha = self.log_alpha.exp()

    def _soft_update(self, target, source):
        for tp, sp in zip(target.parameters(), source.parameters()):
            tp.data.copy_((1 - self.tau) * tp.data + self.tau * sp.data)
            
    def _actor_loss(self, state):
        action, action_probs, action_log_probs = self.actor(state, return_probs=True)

        Q_target = torch.min(
            self.critic1(state), 
            self.critic2(state)
        )
        # True expectation over all actions + sample estimate for expectation over states
        loss = ((self.alpha.detach() * action_log_probs - Q_target.detach()) * action_probs).sum(dim=1).mean()
        
        assert action_log_probs.shape == Q_target.shape == action_probs.shape

        return loss

    def _critic_loss(self, state, action, reward, next_state, done):
        with torch.no_grad():
            next_action, next_action_probs, next_action_log_probs = self.actor(next_state, return_probs=True)
            
            Q_min = torch.min(
                self.target_critic1(next_state),
                self.target_critic2(next_state)
            )
            # True expectation over actions to estimate V(s')
            Q_next = (next_action_probs * (Q_min - self.alpha * next_action_log_probs)).sum(dim=1)
            Q_target = reward + self.gamma * (1 - done) * Q_next

            assert Q_next.shape == reward.shape
            assert next_action_probs.shape == Q_min.shape == next_action_log_probs.shape
        
        # NOTE: gather need (batch_size, 1) action shape, not (batch_size,)
        Q1 = self.critic1(state).gather(1, action.reshape(-1, 1).long()).view(-1)
        Q2 = self.critic2(state).gather(1, action.reshape(-1, 1).long()).view(-1)
        
        loss = F.mse_loss(Q1, Q_target) + F.mse_loss(Q2, Q_target)

        assert Q1.shape == Q_target.shape and Q2.shape == Q_target.shape
        
        return loss
        
    def _alpha_loss(self, state):
        with torch.no_grad():
            action, action_probs, action_log_probs = self.actor(state, return_probs=True)
            # https://github.com/yining043/SAC-discrete/issues/2#event-3685116634
            action_log_probs_exp = (action_log_probs * action_probs).sum(dim=1)

        loss = (-self.log_alpha * (action_log_probs_exp + self.target_entropy)).mean()

        return loss

    def update(self, batch):
        state, action, reward, next_state, done = batch
        
        state = torch.tensor(state, device=DEVICE, dtype=torch.float32)
        action = torch.tensor(action, device=DEVICE, dtype=torch.float32)
        reward = torch.tensor(reward, device=DEVICE, dtype=torch.float32)
        next_state = torch.tensor(next_state, device=DEVICE, dtype=torch.float32)
        done = torch.tensor(done, device=DEVICE, dtype=torch.float32)
        
        # Critic1 & Critic2 update
        critic_losses = self._critic_loss(state, action, reward, next_state, done)

        self.critic1_optimizer.zero_grad()
        self.critic2_optimizer.zero_grad()
        critic_losses.backward()
        self.critic1_optimizer.step()
        self.critic2_optimizer.step()
        
        # Actor update    
        actor_loss = self._actor_loss(state)
        
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Alpha update
        alpha_loss = self._alpha_loss(state)

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        self.alpha = self.log_alpha.exp()

        #  Target networks soft update
        with torch.no_grad():
            self._soft_update(self.target_critic1, self.critic1)
            self._soft_update(self.target_critic2, self.critic2)

    def act(self, state, eval_mode=False):
        with torch.no_grad():
            state = torch.tensor(state, device=DEVICE, dtype=torch.float32)
            action = self.actor(state, eval_mode=eval_mode).cpu().numpy().item()
        return action
    
    def save(self, name):
        torch.save(self.actor.state_dict(), f"{name}.pt")

In [26]:
def evaluate_policy(env_name, agent, seed, episodes=5):
    env = gym.make(env_name)
    
    set_seed(env, seed)
    
    returns = []
    for _ in range(episodes):
        done = False
        state = env.reset()
        total_reward = 0.
        
        while not done:
            state, reward, done, _ = env.step(agent.act(state, eval_mode=True))
            total_reward += reward
        returns.append(total_reward)
    
    return np.mean(returns), np.std(returns)


def train(env_name, model, seed=0, timesteps=500_000, start_steps=10_000, start_train=1000, 
          buffer_size=100_000, batch_size=512, test_episodes=10, test_every=5000, update_every=10):
    print("Training on: ", DEVICE)
    
    env = gym.make(env_name)
    set_seed(env, seed)
    
    buffer = ReplayBuffer(size=buffer_size)
    best_reward = -np.inf
    
    means, stds = [], []
    
    done, state = False, env.reset()
    
    for t in range(timesteps):
        if done:
            done, state = False, env.reset()
    
        if t > start_train:
            action = model.act(state)
        else:
            action = env.action_space.sample()

        next_state, reward, done, _ = env.step(action)        
        buffer.add((state, action, reward, next_state, done))
    
        state = next_state
        
        if t > start_train:
            if t % update_every == 0:
                for _ in range(update_every):
                    batch = buffer.sample(batch_size)
                    model.update(batch)
            
            if t % test_every == 0 or t == timesteps - 1:
                mean, std = evaluate_policy(env_name, model, seed=seed, episodes=test_episodes)
                print(f"Step: {t + 1}, Reward mean: {mean}, Reward std: {std}, Alpha: {model.alpha.detach().cpu().item()}")
                
                if mean > best_reward:
                    best_reward = mean
                    model.save(f"best_agent")
                
                model.save(f"last_agent")
    
                means.append(mean)
                stds.append(std)
    
    return np.array(means), np.array(stds)

In [31]:
config = {
    "agent": {
        "state_size": 8,
        "action_size": 4,
        "hidden_size": 256,
        "gamma": 0.99,
        "tau": 0.001,
        "target_entropy_scale": 0.5, # 0.5
        "actor_lr": 2e-4,
        "critic_lr": 5e-4,
        "alpha_lr": 1e-5 # 1e-5
    },
    "trainer": {
        "seed": 0,
        "timesteps": 500_000,
        "start_train": 10_000,
        "buffer_size": 200_000, # better than int(1e6)
        "batch_size": 128,
        "test_episodes": 10,
        "test_every": 5_000,
        "update_every": 16 # 16
    }
}

model = SoftActorCritic(**config["agent"])
mean, std = train("LunarLander-v2", model, **config["trainer"])

Training on:  cpu
Step: 15001, Reward mean: -133.17716541093182, Reward std: 112.7897131906716, Alpha: 0.9701305627822876
Step: 20001, Reward mean: -306.56213472737164, Reward std: 101.9606333482044, Alpha: 0.9440081715583801
Step: 25001, Reward mean: -243.99490862548282, Reward std: 80.72378346974962, Alpha: 0.9334538578987122
Step: 30001, Reward mean: -320.6924113084025, Reward std: 76.14900610805385, Alpha: 0.9224791526794434
Step: 35001, Reward mean: -249.19914899674717, Reward std: 104.39374501747429, Alpha: 0.8974756598472595
Step: 40001, Reward mean: -33.65693273371791, Reward std: 33.649856075483804, Alpha: 0.8644418120384216
Step: 45001, Reward mean: -145.39037736904862, Reward std: 195.11835006181573, Alpha: 0.8276879787445068
Step: 50001, Reward mean: -6.984071233585493, Reward std: 98.56343890939584, Alpha: 0.7903392314910889
Step: 55001, Reward mean: -38.3534994829545, Reward std: 34.71027722539631, Alpha: 0.7543269991874695
Step: 60001, Reward mean: -18.367844848200754, R

Step: 410001, Reward mean: 263.76928066978905, Reward std: 20.09837264035266, Alpha: 0.24392205476760864
Step: 415001, Reward mean: 258.87617431820644, Reward std: 18.829408812404335, Alpha: 0.24353379011154175
Step: 420001, Reward mean: 254.00089449362773, Reward std: 13.806342092321147, Alpha: 0.24307413399219513
Step: 425001, Reward mean: 271.4616880709311, Reward std: 20.109551818209056, Alpha: 0.24234504997730255
Step: 430001, Reward mean: 260.36916988678433, Reward std: 24.153386429750697, Alpha: 0.24122145771980286
Step: 435001, Reward mean: 248.5031923294765, Reward std: 18.226278315623222, Alpha: 0.24139465391635895
Step: 440001, Reward mean: 252.07126946737156, Reward std: 14.671098940396414, Alpha: 0.23885147273540497
Step: 445001, Reward mean: 271.05518137527264, Reward std: 19.43150894518037, Alpha: 0.23836444318294525
Step: 450001, Reward mean: 265.1309925113035, Reward std: 17.420395767192723, Alpha: 0.23740257322788239
Step: 455001, Reward mean: 261.80858486030485, Rewa

In [32]:
config1 = {
    "agent": {
        "state_size": 8,
        "action_size": 4,
        "hidden_size": 32,
        "gamma": 0.99,
        "tau": 0.001,
        "target_entropy_scale": 0.5, # 0.5
        "actor_lr": 2e-4,
        "critic_lr": 5e-4,
        "alpha_lr": 1e-5 # 1e-5
    },
    "trainer": {
        "seed": 0,
        "timesteps": 500_000,
        "start_train": 10_000,
        "buffer_size": 200_000, # better than int(1e6)
        "batch_size": 128,
        "test_episodes": 10,
        "test_every": 5_000,
        "update_every": 16 # 16
    }
}

model1 = SoftActorCritic(**config1["agent"])
mean1, std1 = train("LunarLander-v2", model1, **config1["trainer"])

Training on:  cpu
Step: 15001, Reward mean: -171.79445337753612, Reward std: 21.144142598497574, Alpha: 0.9613027572631836
Step: 20001, Reward mean: -136.55867166895368, Reward std: 22.70925298460754, Alpha: 0.9203715920448303
Step: 25001, Reward mean: -118.87268968611443, Reward std: 22.955041044022288, Alpha: 0.8770214319229126
Step: 30001, Reward mean: -116.2477305722422, Reward std: 22.043151304277806, Alpha: 0.8369368314743042
Step: 35001, Reward mean: -74.91471377896605, Reward std: 23.09738939391834, Alpha: 0.8026368021965027
Step: 40001, Reward mean: -23.8266696096837, Reward std: 24.081517606353543, Alpha: 0.7704234719276428
Step: 45001, Reward mean: -0.9064542077316127, Reward std: 21.95486304383988, Alpha: 0.7393550872802734
Step: 50001, Reward mean: -2.056823901816847, Reward std: 22.640738533302873, Alpha: 0.710578978061676
Step: 55001, Reward mean: -11.75039547137502, Reward std: 22.638271968221492, Alpha: 0.6840417385101318
Step: 60001, Reward mean: 44.59825678978723, Re

Step: 410001, Reward mean: 179.21532990048104, Reward std: 81.82498783850546, Alpha: 0.20782804489135742
Step: 415001, Reward mean: 127.37628775506582, Reward std: 101.07372094012075, Alpha: 0.20849010348320007
Step: 420001, Reward mean: 167.2306707739204, Reward std: 103.1975844256442, Alpha: 0.20752054452896118
Step: 425001, Reward mean: 115.64750333224443, Reward std: 100.91810013838631, Alpha: 0.20657847821712494
Step: 430001, Reward mean: 94.16646484858792, Reward std: 110.00212085965082, Alpha: 0.20775116980075836
Step: 435001, Reward mean: 158.2136428087138, Reward std: 91.28451668858347, Alpha: 0.20719581842422485
Step: 440001, Reward mean: 48.581487965619715, Reward std: 119.5568279782198, Alpha: 0.20697815716266632
Step: 445001, Reward mean: 137.59563185768798, Reward std: 86.91682977845039, Alpha: 0.20851373672485352
Step: 450001, Reward mean: 147.41157018239574, Reward std: 106.82272611396827, Alpha: 0.20984292030334473
Step: 455001, Reward mean: 138.11774729979345, Reward 