Code Reference: https://github.com/zhihanyang2022/pytorch-sac

In [1]:
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from mlagents_envs.environment import UnityEnvironment
from torch.distributions import Normal, Independent

In [3]:
# Pick the compute device once; every network and tensor below is moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # also show which GPU was picked up
    print(device, torch.cuda.get_device_name(0))
else:
    print(device)

cuda NVIDIA GeForce RTX 3060


### Replay Buffer

In [4]:
from collections import namedtuple, deque

# One stored step and one sampled batch share the same field layout:
# state, action, reward, next state, done flag.
Transition = namedtuple('Transition', ['s', 'a', 'r', 'ns', 'd'])
Batch = namedtuple('Batch', ['s', 'a', 'r', 'ns', 'd'])

In [5]:
# namedtuple demo: fields can be passed by keyword and are shown by name in the repr
User = namedtuple('User', ['name', 'age', 'id'])
user = User(name='tester', age='22', id='464643123')
print(user)

User(name='tester', age='22', id='464643123')


In [6]:
class ReplayBuffer(object):
    """Bounded store of `Transition` records with uniform random sampling."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        # bounded deque: once full, each appendleft drops the entry at the right end
        self.memory = deque(maxlen=capacity)

    def push(self, transition: Transition) -> None:
        """Insert `transition` at the left end of the buffer."""
        self.memory.appendleft(transition)

    def ready_for(self, batch_size: int) -> bool:
        """Return True once at least `batch_size` transitions are stored."""
        return len(self.memory) >= batch_size

    def sample(self, batch_size: int) -> Batch:
        """Draw `batch_size` stored transitions and stack each field into one float tensor on `device`.

        Every field is row-stacked with `np.vstack`, so a transition that holds several
        rows contributes all of them to the returned batch.
        """
        drawn = [e for e in random.sample(self.memory, batch_size) if e is not None]

        def stacked(field: str) -> torch.Tensor:
            rows = np.vstack([getattr(e, field) for e in drawn])
            return torch.from_numpy(rows).float().to(device)

        return Batch(*[stacked(field) for field in ('s', 'a', 'r', 'ns', 'd')])

In [7]:
# test buffer
# Smoke test: push one transition made of plain ints and read back its `a` and `r` fields.
b = ReplayBuffer(capacity=5)  
b.push(Transition(1,2, 3, 4, 5))
print(b.memory[0].a, b.memory[0].r)

2 3


### NN

In [8]:
def get_net(
        num_in:int,
        num_out:int,
        final_activation,  # e.g. nn.Tanh
        num_hidden_layers:int=5,
        num_neurons_per_hidden_layer:int=64
    ) -> nn.Sequential:
    """Build a fully connected ReLU network.

    Layout: one input layer (num_in -> width), then `num_hidden_layers` width -> width
    layers, each followed by ReLU, then a final Linear(width -> num_out). When
    `final_activation` is not None that module instance is appended as the last stage.
    """
    width = num_neurons_per_hidden_layer

    # input size of every ReLU-activated layer, in construction order
    fan_ins = [num_in] + [width] * num_hidden_layers

    stages = []
    for fan_in in fan_ins:
        stages += [nn.Linear(fan_in, width), nn.ReLU()]

    stages.append(nn.Linear(width, num_out))

    if final_activation is not None:
        stages.append(final_activation)

    return nn.Sequential(*stages)

In [9]:
class NormalPolicyNet(nn.Module):

    """Maps states to a diagonal Gaussian whose mean and log-std come from learnable heads."""

    # bounds applied to the predicted log-std before exponentiating
    LOG_STD_MIN = -20
    LOG_STD_MAX = 2

    def __init__(self, input_dim, action_dim):
        super().__init__()
        # shared trunk ending in 64 ReLU features, followed by two linear heads
        self.shared_net   = get_net(num_in=input_dim, num_out=64, final_activation=nn.ReLU())
        self.means_net    = nn.Linear(64, action_dim)
        self.log_stds_net = nn.Linear(64, action_dim)

    def forward(self, states: torch.tensor):
        """Return Independent(Normal(mean, std), 1) for `states`; the last axis is treated as one event."""
        features = self.shared_net(states)
        loc = self.means_net(features)

        # predicting log-std and exponentiating is better behaved under gradients than
        # producing std directly with nn.Softplus()
        # ref: https://github.com/openai/spinningup/blob/master/spinup/algos/pytorch/sac/core.py#L26
        log_scale = torch.clamp(self.log_stds_net(features), self.LOG_STD_MIN, self.LOG_STD_MAX)
        gaussian = Normal(loc=loc, scale=torch.exp(log_scale))

        return Independent(gaussian, reinterpreted_batch_ndims=1)

In [10]:
class QNet(nn.Module):

    """State-action value network; wraps the concatenation of state and action so callers need not do it."""

    def __init__(self, input_dim, action_dim):
        super().__init__()
        joint_dim = input_dim + action_dim
        self.net = get_net(num_in=joint_dim, num_out=1, final_activation=None)

    def forward(self, states: torch.tensor, actions: torch.tensor):
        """Join `states` and `actions` along dim 1 and return one value per row (last dim of size 1)."""
        joint = torch.cat((states, actions), dim=1)
        return self.net(joint)

# Agent

In [11]:
class Agent:
    """Soft Actor-Critic learner: a tanh-squashed Gaussian policy plus twin Q-networks with target copies.

    The "Step 12" ... "Step 15" labels in the comments are the original author's step
    numbers for the algorithm listing they followed.

    Args:
        input_dim: length of one observation vector.
        action_dim: length of one action vector.
        lr: Adam learning rate used for the policy and both Q-networks.
        gamma: discount factor applied to the bootstrapped next-state value.
        alpha: weight of the log-probability (entropy) term in the targets and the policy loss.
        polyak: fraction of each target parameter that is kept on a target update.
    """

    def __init__(self, input_dim, action_dim, lr=1e-3, gamma=0.99, alpha=0.1, polyak=0.995):

        self.Normal   = NormalPolicyNet(input_dim=input_dim, action_dim=action_dim).to(device)
        self.Normal_optimizer = optim.Adam(self.Normal.parameters(), lr=lr)

        # each target network starts as an exact copy of its online network
        self.Q1       = QNet(input_dim=input_dim, action_dim=action_dim).to(device)
        self.Q1_targ  = QNet(input_dim=input_dim, action_dim=action_dim).to(device)
        self.Q1_targ.load_state_dict(self.Q1.state_dict())
        self.Q1_optimizer = optim.Adam(self.Q1.parameters(), lr=lr)

        self.Q2       = QNet(input_dim=input_dim, action_dim=action_dim).to(device)
        self.Q2_targ  = QNet(input_dim=input_dim, action_dim=action_dim).to(device)
        self.Q2_targ.load_state_dict(self.Q2.state_dict())
        self.Q2_optimizer = optim.Adam(self.Q2.parameters(), lr=lr)

        self.gamma = gamma
        self.alpha = alpha
        self.polyak = polyak

    # ==================================================================================================================
    # Helper methods
    # ==================================================================================================================

    def min_i_12(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Element-wise minimum of the two Q estimates."""
        return torch.min(a, b)

    def sample_action_and_compute_log_pi(self, state, use_reparametrization_trick):
        """Sample a tanh-squashed action for every row of `state` and its log-probability.

        Args:
            state: batch of states accepted by the policy network.
            use_reparametrization_trick: draw with `rsample` (gradients flow back to the
                policy) when True, with `sample` otherwise.

        Returns:
            (a, log_pi): `a` = tanh(u) has the shape of the policy mean; `log_pi` has one
            entry per row, i.e. no trailing action axis.
        """
        mu_given_s = self.Normal(state)  # in paper, mu represents the normal distribution
        # in paper, u represents the un-squashed action
        u = mu_given_s.rsample() if use_reparametrization_trick else mu_given_s.sample()
        a = torch.tanh(u)
        # the following line of code is not numerically stable:
        # log_pi_a_given_s = mu_given_s.log_prob(u) - torch.sum(torch.log(1 - torch.tanh(u) ** 2), dim=1)
        # ref: https://github.com/vitchyr/rlkit/blob/0073d73235d7b4265cd9abe1683b30786d863ffe/rlkit/torch/distributions.py#L358
        # ref: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L73
        # 2 * (log 2 - u - softplus(-2u)) equals log(1 - tanh(u)^2); summing over the last
        # axis matches the axis that Independent(..., 1).log_prob reduces
        log_pi_a_given_s = mu_given_s.log_prob(u) - (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(dim=-1)
        return a, log_pi_a_given_s

    def clip_gradient(self, net: nn.Module) -> None:
        """Clamp every existing gradient of `net` to [-1, 1]; parameters without a gradient are skipped."""
        nn.utils.clip_grad_value_(net.parameters(), clip_value=1.0)

    def polyak_update(self, old_net: nn.Module, new_net: nn.Module) -> None:
        """Move `old_net` towards `new_net` in place: old <- polyak * old + (1 - polyak) * new."""
        with torch.no_grad():
            for old_param, new_param in zip(old_net.parameters(), new_net.parameters()):
                old_param.copy_(old_param * self.polyak + new_param * (1 - self.polyak))

    @staticmethod
    def _set_requires_grad(net: nn.Module, flag: bool) -> None:
        """Switch gradient tracking on or off for every parameter of `net`."""
        for param in net.parameters():
            param.requires_grad = flag

    def _compute_targets(self, b: "Batch") -> torch.Tensor:
        """Step 12: bootstrapped Q targets, same shape as `b.r`, computed without gradients."""
        with torch.no_grad():
            na, log_pi_na_given_ns = self.sample_action_and_compute_log_pi(b.ns, use_reparametrization_trick=False)
            min_q_next = self.min_i_12(self.Q1_targ(b.ns, na), self.Q2_targ(b.ns, na))
            # log-probs carry no trailing axis while the Q outputs end in a size-1 axis;
            # without unsqueeze the subtraction broadcasts to (batch, batch) and pairs every
            # row's Q-value with every other row's log-prob
            soft_value = min_q_next - self.alpha * log_pi_na_given_ns.unsqueeze(-1)
            return b.r + self.gamma * (1 - b.d) * soft_value

    def _fit_q(self, q_net: nn.Module, optimizer, b: "Batch", targets: torch.Tensor) -> None:
        """Step 13 for one Q-network: one clipped Adam step on the mean squared error to `targets`."""
        loss = F.mse_loss(q_net(b.s, b.a), targets)
        optimizer.zero_grad()
        loss.backward()
        self.clip_gradient(net=q_net)
        optimizer.step()

    # ==================================================================================================================
    # Methods for learning
    # ==================================================================================================================

    def update_networks(self, b: "Batch") -> None:
        """Run one SAC update (targets, both Q-networks, policy, target networks) on batch `b`.

        `b` must expose tensors `s`, `a`, `r`, `ns`, `d`; `r` and `d` need the same shape
        as the Q-network output.
        """

        # Step 12: calculating targets
        targets = self._compute_targets(b)

        # Step 13: learning the Q functions
        self._fit_q(self.Q1, self.Q1_optimizer, b, targets)
        self._fit_q(self.Q2, self.Q2_optimizer, b, targets)

        # Step 14: learning the policy
        # The Q-networks are only differentiated through here, so their parameters are
        # frozen; `finally` guarantees they are unfrozen again even if this step raises.
        self._set_requires_grad(self.Q1, False)
        self._set_requires_grad(self.Q2, False)
        try:
            a, log_pi_a_given_s = self.sample_action_and_compute_log_pi(b.s, use_reparametrization_trick=True)
            min_q = self.min_i_12(self.Q1(b.s, a), self.Q2(b.s, a))
            # same trailing-axis alignment as in the targets
            policy_loss = - torch.mean(min_q - self.alpha * log_pi_a_given_s.unsqueeze(-1))

            self.Normal_optimizer.zero_grad()
            policy_loss.backward()
            self.clip_gradient(net=self.Normal)
            self.Normal_optimizer.step()
        finally:
            self._set_requires_grad(self.Q1, True)
            self._set_requires_grad(self.Q2, True)

        # Step 15: update target networks
        self.polyak_update(old_net=self.Q1_targ, new_net=self.Q1)
        self.polyak_update(old_net=self.Q2_targ, new_net=self.Q2)

    def act(self, state) -> np.ndarray:
        """Sample actions for `state` (a tensor already on the policy's device) as a NumPy array."""
        # no graph is needed when acting, so sampling runs under no_grad
        with torch.no_grad():
            action, _ = self.sample_action_and_compute_log_pi(state, use_reparametrization_trick=False)
        return action.cpu().numpy()

    def save_actor(self, save_dir: str, filename: str) -> None:
        """Write the policy weights to `save_dir/filename`, creating the directory when missing."""
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.Normal.state_dict(), os.path.join(save_dir, filename))

    def load_actor(self, save_dir: str, filename: str) -> None:
        """Load policy weights from `save_dir/filename` onto the device the policy currently lives on."""
        # map_location lets a checkpoint saved on one device be restored on another
        target_device = next(self.Normal.parameters()).device
        state_dict = torch.load(os.path.join(save_dir, filename), map_location=target_device)
        self.Normal.load_state_dict(state_dict)

### Instantiate NN

In [103]:
# Length of one observation vector and of one action vector; they match the
# (3, 243) and (3, 39) shapes printed by the cells below.
N_STATES  = 243
N_ACTIONS =39

# Number of agent rows expected in every decision step.
N_AGENTS = 3  
# NOTE(review): hidden_units is not read by any cell visible here — confirm it is still needed.
hidden_units = 256

In [104]:
# Small replay buffer (1000 transitions) for this walkthrough; the trailing note records the Walker.yaml value.
buf = ReplayBuffer(capacity=int(1e3))   #2M in Walker.yaml

In [105]:
# Build the agent (policy, twin Q-networks and their target copies) for this observation/action size.
agent = Agent(
    input_dim=N_STATES,
    action_dim=N_ACTIONS
)

In [106]:
# Number of stored transitions drawn per update. Each transition holds N_AGENTS rows,
# so a sampled batch has batch_size * N_AGENTS rows (64 * 3 = 192 in the shapes printed below).
batch_size = 64 #1024 in Walker.yaml

# Step-by-step test of the training loop

In [107]:
# NOTE(review): file_name=None presumably attaches to a running Unity Editor instead of launching a build — confirm.
env = UnityEnvironment(file_name= None, base_port=5004)

In [108]:
# Reset the environment, then use the first registered behavior name for all later calls.
env.reset()
behaviorNames = list(env.behavior_specs.keys())
behaviorName = behaviorNames[0]
print(behaviorName)

Walker?team=0


In [109]:
# Fetch the current step data for this behavior as a (decision steps, terminal steps) pair.
DecisionSteps, TerminalSteps = env.get_steps(behaviorName)

In [110]:
# First observation array of the decision steps: one row per agent (printed shape is (3, 243)).
s = DecisionSteps.obs[0]
print(s.shape)

(3, 243)


In [111]:
# ==================================================
# getting the tuple (s, a, r, s', done)
# ==================================================

# action = param.act(states)
# param.act calls sample_action_and_compute_log_pi(...) to calculate action

# step-by-step run sample_action_and_compute_log_pi
# First stage: run the policy network on the observations to get the action distribution.
mu_given_s = agent.Normal(torch.FloatTensor(s).to(device)) 
print(mu_given_s)

Independent(Normal(loc: torch.Size([3, 39]), scale: torch.Size([3, 39])), 1)


In [112]:
# Second stage: draw the un-squashed action u with the reparameterised sampler.
u = mu_given_s.rsample() 
print(u.shape)

torch.Size([3, 39])


In [113]:
# Third stage: squash u with tanh to obtain the bounded action.
a = torch.tanh(u)
print(a.shape)

torch.Size([3, 39])


In [114]:
# Fourth stage: log-probability of the squashed action — Gaussian log-prob of u minus the
# tanh correction summed over the action axis; the result has one value per agent row.
log_pi_a_given_s = mu_given_s.log_prob(u) - (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(dim=1)
print(log_pi_a_given_s)

tensor([-25.4253, -23.9051, -26.7177], device='cuda:0', grad_fn=<SubBackward0>)


In [115]:
# Shape of the action once detached from the graph and converted to a NumPy array on the CPU.
print(a.detach().cpu().numpy().shape)

(3, 39)


In [116]:
# finish step-by-step run sample_action_and_compute_log_pi, try call the function directly
# The log-probability is discarded here; only the action is converted to NumPy.
a, _ = agent.sample_action_and_compute_log_pi(torch.FloatTensor(s).to(device), use_reparametrization_trick=False)
a = a.detach().cpu().numpy()
print(a.shape)

(3, 39)


In [117]:
# Same result through the public entry point: act() returns the NumPy action array directly.
a = agent.act(torch.FloatTensor(s).to(device))
print(a.shape)

(3, 39)


In [118]:
from mlagents_envs.base_env import ActionTuple

In [119]:
# Minimal construction example with a 1x2 float32 array.
# NOTE(review): the positional argument is presumably the continuous-action array — confirm.
ActionTuple(np.array([[1.0,0.0]], dtype=np.float32))

<mlagents_envs.base_env.ActionTuple at 0x2058df47670>

In [120]:
# Wrap the sampled float32 actions in an ActionTuple; this rebinds `a` from an ndarray to an ActionTuple.
a = ActionTuple(np.array(a, dtype=np.float32))

In [121]:
# Hand the actions for this behavior to the environment, then advance the simulation.
env.set_actions(behaviorName, a)
env.step()

In [122]:
# next_obs, reward, done, _ = env.step(action)
# Read the step that followed the action: next observation and reward per agent.
NextDecisionSteps, NextTerminalSteps = env.get_steps(behaviorName)
ns = NextDecisionSteps.obs[0]
reward = NextDecisionSteps.reward
# add a trailing axis so reward is a column, matching the (3, 1) shape printed below
reward = np.expand_dims(reward, axis=1)
# NOTE(review): done is hard-coded to 0 for every agent and NextTerminalSteps is not consulted — confirm intended.
done = np.array([[0]]*N_AGENTS ) 
print(ns.shape, ', ', reward.shape, ', ', done.shape)

(3, 243) ,  (3, 1) ,  (3, 1)


In [123]:
# Per-agent rewards of this step followed by their mean.
print(reward, "\n", np.mean(reward))

[[0.]
 [0.]
 [0.]] 
 0.0


In [124]:
# `a` is now the ActionTuple wrapper, not the raw array (see the repr below).
a

<mlagents_envs.base_env.ActionTuple at 0x205b32f0ca0>

In [125]:
# from ActionTuple to np array
# NOTE(review): `_continuous` is an underscore-prefixed attribute; a public accessor, if one exists, would be safer.
print(a._continuous.shape)

(3, 39)


In [126]:
# ==================================================
# storing it to the buffer
# ==================================================

#buf.push(Transition(obs, action, reward, next_obs, done))
# One Transition stores the arrays of all agents together, one row per agent.
buf.push(Transition(s, a._continuous, reward, ns, done))

In [127]:
# memory[0] is the most recently pushed transition (push inserts at the left end).
print(buf.memory[0].a.shape, buf.memory[0].r.shape)

(3, 39) (3, 1)


In [128]:
# run a loop to fill the buffer with batch_size so we can update NN
DecisionSteps, TerminalSteps = env.get_steps(behaviorName)
for i in range(2*batch_size):
    # act on the current observations of all agents and advance the simulation
    s = DecisionSteps.obs[0]
    a = agent.act(torch.FloatTensor(s).to(device))
    a = ActionTuple(np.array(a, dtype=np.float32))
    env.set_actions(behaviorName, a)
    env.step()
    NextDecisionSteps, NextTerminalSteps = env.get_steps(behaviorName)
    
    #if next decision step misses some agents, then reset 
    # NOTE(review): the transition that led to this reset is dropped and NextTerminalSteps is never read,
    # so no terminal transition (done = 1) ever reaches the buffer — confirm this is intended.
    if(len(NextDecisionSteps)!= N_AGENTS): 
        print(i, " reset training!")
        env.reset()    
        DecisionSteps, TerminalSteps = env.get_steps(behaviorName)
    else: 
        # all agents reported: store one transition holding every agent's row
        ns = NextDecisionSteps.obs[0]
        reward = NextDecisionSteps.reward
        reward = np.expand_dims(reward, axis=1)
        done = np.array([[0]]*N_AGENTS ) 
        buf.push(Transition(s, a._continuous, reward, ns, done))
        # the step just observed becomes the current step of the next iteration
        DecisionSteps, TerminalSteps = NextDecisionSteps, NextTerminalSteps

13  reset training!
32  reset training!
49  reset training!
67  reset training!
88  reset training!
103  reset training!


In [129]:
# Number of stored transitions; iterations that ended in a reset did not push anything.
len(buf.memory)

123

In [130]:
# agent.update_networks(buf.sample(batch_size))

#step by step to run update_network

# buf.sample(batch_size)
#Transition = namedtuple('Transition', 's a r ns d')
#Batch = namedtuple('Batch', 's a r ns d')
# Draw batch_size stored transitions; each one still holds the rows of all N_AGENTS agents.
experiences = random.sample(buf.memory, batch_size) 
print(batch_size, N_AGENTS, len(experiences))
print(experiences[0].s.shape, experiences[0].a.shape, experiences[0].r.shape, \
      experiences[0].ns.shape, experiences[0].d.shape)

64 3 64
(3, 243) (3, 39) (3, 1) (3, 243) (3, 1)


In [131]:
# Row-stacking flattens (transition, agent) into a single batch axis: 64 * 3 = 192 rows.
np.vstack([e.s for e in experiences if e is not None]).shape

(192, 243)

In [133]:
#Find all states in memory batch, append vertically, convert to torch with float values and send to CPU or GPU
# This cell performs only the stacking; the tensor conversion follows in the next cells.
bs  = np.vstack([e.s for e in experiences if e is not None])
ba  = np.vstack([e.a for e in experiences if e is not None])
br  = np.vstack([e.r for e in experiences if e is not None])
b_ns = np.vstack([e.ns for e in experiences if e is not None])
bd  = np.vstack([e.d for e in experiences if e is not None])
print(bs.shape, ba.shape, br.shape, b_ns.shape, bd.shape)

(192, 243) (192, 39) (192, 1) (192, 243) (192, 1)


In [134]:
# Same stacking for the states, now converted to a float tensor on `device`.
bs  = torch.from_numpy(np.vstack([e.s for e in experiences if e is not None])).float().to(device)
print(bs.shape)

torch.Size([192, 243])


In [135]:
# All five fields as float tensors on `device`; this mirrors the body of ReplayBuffer.sample.
bs  = torch.from_numpy(np.vstack([e.s for e in experiences if e is not None])).float().to(device)
ba  = torch.from_numpy(np.vstack([e.a for e in experiences if e is not None])).float().to(device)
br  = torch.from_numpy(np.vstack([e.r for e in experiences if e is not None])).float().to(device)
b_ns  = torch.from_numpy(np.vstack([e.ns for e in experiences if e is not None])).float().to(device)
bd  = torch.from_numpy(np.vstack([e.d for e in experiences if e is not None])).float().to(device)
print(bs.shape, ba.shape, br.shape, b_ns.shape, bd.shape)

torch.Size([192, 243]) torch.Size([192, 39]) torch.Size([192, 1]) torch.Size([192, 243]) torch.Size([192, 1])


In [136]:
# finish step-by-step test of buf.sample
# directly call buf.sample(batch_size)
# This draws a fresh random batch; the shapes match the manual version above.
batch = buf.sample(batch_size)
print(batch.s.shape, batch.a.shape, batch.r.shape, batch.ns.shape, batch.d.shape)

torch.Size([192, 243]) torch.Size([192, 39]) torch.Size([192, 1]) torch.Size([192, 243]) torch.Size([192, 1])


In [137]:
# ========================================
# Step 12: calculating targets
# ========================================

with torch.no_grad():

    na, log_pi_na_given_ns = agent.sample_action_and_compute_log_pi(b_ns, use_reparametrization_trick=False)
    # log_pi_na_given_ns is (batch,) while the Q outputs are (batch, 1); add a trailing axis so
    # the subtraction stays row-wise instead of broadcasting to a (batch, batch) matrix
    targets = br + agent.gamma * (1 - bd) * \
              (agent.min_i_12(agent.Q1_targ(b_ns, na), agent.Q2_targ(b_ns, na))
               - agent.alpha * log_pi_na_given_ns.unsqueeze(-1))

# one target per sampled row
assert targets.shape == br.shape


In [138]:
# ========================================
# Step 13: learning the Q functions
# ========================================

# Mean squared error between the Q1 predictions for the stored (state, action) pairs and the targets.
# NOTE(review): assumes targets has the same (batch, 1) shape as Q1_predictions; any other
# shape would broadcast inside the subtraction — confirm.
Q1_predictions = agent.Q1(bs, ba)
Q1_loss = torch.mean((Q1_predictions - targets) ** 2)

In [139]:
# Show the loss both as a tensor (with its grad_fn) and as a plain Python float.
print(Q1_loss, "\n", float(Q1_loss))

tensor(5.9385, device='cuda:0', grad_fn=<MeanBackward0>) 
 5.938521862030029


In [140]:
# One optimisation step for Q1: clear old gradients, backpropagate, clamp gradients to [-1, 1], apply Adam.
agent.Q1_optimizer.zero_grad()
Q1_loss.backward()
agent.clip_gradient(net=agent.Q1)
agent.Q1_optimizer.step()

In [141]:
# Same loss for the second Q-network, against the same targets.
Q2_predictions = agent.Q2(bs, ba)
Q2_loss = torch.mean((Q2_predictions - targets) ** 2)

In [142]:
# Show the Q2 loss as a tensor and as a plain Python float.
print(Q2_loss, "\n", float(Q2_loss))

tensor(6.8861, device='cuda:0', grad_fn=<MeanBackward0>) 
 6.886084079742432


In [143]:
# One optimisation step for Q2, mirroring the Q1 step.
agent.Q2_optimizer.zero_grad()
Q2_loss.backward()
agent.clip_gradient(net=agent.Q2)
agent.Q2_optimizer.step()

In [144]:
# ========================================
# Step 14: learning the policy
# ========================================

# Freeze both Q-networks: the policy loss is backpropagated through them, but their
# own parameters should not accumulate gradients during the policy step.
for param in agent.Q1.parameters():
    param.requires_grad = False
for param in agent.Q2.parameters():
    param.requires_grad = False

In [145]:
# Reparameterised sample so the loss can be differentiated with respect to the policy parameters.
a, log_pi_a_given_s = agent.sample_action_and_compute_log_pi(bs, use_reparametrization_trick=True)
# log_pi_a_given_s is (batch,) and the Q outputs are (batch, 1): the trailing axis keeps the
# subtraction row-wise instead of materialising a (batch, batch) matrix before the mean
policy_loss = - torch.mean(agent.min_i_12(agent.Q1(bs, a), agent.Q2(bs, a)) - agent.alpha * log_pi_a_given_s.unsqueeze(-1))

In [146]:
# Show the policy loss as a tensor and as a plain Python float.
print(policy_loss, "\n", float(policy_loss))

tensor(-2.5262, device='cuda:0', grad_fn=<NegBackward0>) 
 -2.526196002960205


In [147]:
# One optimisation step for the policy network: clear, backpropagate, clamp gradients to [-1, 1], apply Adam.
agent.Normal_optimizer.zero_grad()
policy_loss.backward()
agent.clip_gradient(net=agent.Normal)
agent.Normal_optimizer.step()

In [148]:
# Unfreeze both Q-networks so the next Q update can compute their gradients again.
for param in agent.Q1.parameters():
    param.requires_grad = True
for param in agent.Q2.parameters():
    param.requires_grad = True

In [149]:
# ========================================
# Step 15: update target networks
# ========================================

# Blend each target network towards its online network: targ <- polyak * targ + (1 - polyak) * online.
with torch.no_grad():
    agent.polyak_update(old_net=agent.Q1_targ, new_net=agent.Q1)
    agent.polyak_update(old_net=agent.Q2_targ, new_net=agent.Q2)

In [150]:
# finish step-by-step test of agent.update_network
# directly call agent.update_networks(buf.sample(batch_size))
# Runs steps 12-15 in one call on a freshly sampled batch.
agent.update_networks(buf.sample(batch_size))

In [151]:
# Close the connection to the Unity environment.
env.close()