In [1]:
import gym
import torch
from torch import nn
from collections import namedtuple, deque
import itertools
from copy import deepcopy
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import tqdm

  import distutils.spawn


In [2]:
RENDER = False

In [3]:
ALPHA = 0.3  # How much we value entropy / exploration, increasing this will increase exploration.
GAMMA = 0.9  # How much we value future rewards.
TAU = 0.01  # How much q_target is updated when polyak averaging (step 15).
POLICY_LR = 0.001  # Policy learning rate.
Q_LR = 0.001  # Q learning rate.

In [4]:
env = gym.make("LunarLander-v2")

In [5]:
SARS = namedtuple('SARS', 'state, action, reward, next_state, t, failed, limit')

In [6]:
softmax = nn.Softmax(dim=0)
input = torch.tensor([1, 2, 3], dtype=float)
display(input)
output = softmax(input)
display(output)
sum(output)

tensor([1., 2., 3.], dtype=torch.float64)

tensor([0.0900, 0.2447, 0.6652], dtype=torch.float64)

tensor(1.0000, dtype=torch.float64)

In [7]:
softmax = nn.Softmax(dim=1)
input = torch.tensor([[1, 2, 3], [1, 2, 3], [3, 3, 3]], dtype=float)
display(input)
output = softmax(input)
display(output)
sum(output)

tensor([[1., 2., 3.],
        [1., 2., 3.],
        [3., 3., 3.]], dtype=torch.float64)

tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652],
        [0.3333, 0.3333, 0.3333]], dtype=torch.float64)

tensor([0.5134, 0.8228, 1.6638], dtype=torch.float64)

In [8]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 900),
            nn.ReLU(),
            nn.Linear(900, 300),
            nn.ReLU(),
            nn.Linear(300, output_size)
        )

    def forward(self, x):
        nn_out = self.linear_relu_stack(x)
        return nn.Softmax(dim=1)(nn_out)

    def __call__(self, x):
        raise RuntimeError("Use forward")

In [9]:
policy_network = PolicyNetwork(4, 2)
policy_network

PolicyNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=4, out_features=900, bias=True)
    (1): ReLU()
    (2): Linear(in_features=900, out_features=300, bias=True)
    (3): ReLU()
    (4): Linear(in_features=300, out_features=2, bias=True)
  )
)

In [10]:
torch.cuda.is_available()

True

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [12]:
mock_states = torch.rand(5, 4)
mock_states

tensor([[0.2641, 0.3076, 0.7351, 0.2398],
        [0.2376, 0.7746, 0.4289, 0.3967],
        [0.5034, 0.5841, 0.0033, 0.4338],
        [0.6472, 0.5605, 0.2108, 0.8040],
        [0.0029, 0.7294, 0.5659, 0.0165]])

In [13]:
policy_network.forward(mock_states)

tensor([[0.5258, 0.4742],
        [0.5204, 0.4796],
        [0.5303, 0.4697],
        [0.5291, 0.4709],
        [0.5199, 0.4801]], grad_fn=<SoftmaxBackward0>)

In [14]:
class QNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 900),
            nn.ReLU(),
            nn.Linear(900, 300),
            nn.ReLU(),
            nn.Linear(300, output_size)
        )

    def forward(self, x):
        nn_out = self.linear_relu_stack(x)
        return nn_out
    
    def __call__(self, x):
        raise RuntimeError("Use forward")

In [15]:
q_network = QNetwork(4, 2)
q_network

QNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=4, out_features=900, bias=True)
    (1): ReLU()
    (2): Linear(in_features=900, out_features=300, bias=True)
    (3): ReLU()
    (4): Linear(in_features=300, out_features=2, bias=True)
  )
)

In [16]:
q_network.forward(mock_states)

tensor([[-0.0003, -0.0840],
        [ 0.0490, -0.0823],
        [ 0.0795, -0.0553],
        [ 0.0684, -0.0540],
        [ 0.0477, -0.1139]], grad_fn=<AddmmBackward0>)

In [17]:
def run_episode(action_f, step_f, env, policy, fail_at_limit=False):
    episode_reward = 0
    s = env.reset()
    for t in itertools.count(start=1):
        a = action_f(policy, s)
        next_state, reward, failed, info = env.step(a)
        episode_reward += reward
        assert t <= env._max_episode_steps
        limit = t == env._max_episode_steps
        if limit and not fail_at_limit:
            failed = False
        assert fail_at_limit or not (limit and failed)
        step_f(s, a, reward, next_state, t, failed, limit)
        if failed or limit:
            break
        s = next_state
    return episode_reward

In [18]:
class Policy:
    def __init__(self, env_state_size, env_action_space_size):
        self.policy_network = PolicyNetwork(env_state_size, env_action_space_size)
        self.q1_network = QNetwork(env_state_size, env_action_space_size)
        self.q2_network = QNetwork(env_state_size, env_action_space_size)
        self.q1_target_network = deepcopy(self.q1_network)
        self.q2_target_network = deepcopy(self.q2_network)
        self.reset_optimizers()

    def reset_optimizers(self):
        self.policy_optimizer = torch.optim.SGD(self.policy_network.parameters(), lr=POLICY_LR)
        self.q1_optimizer = torch.optim.SGD(self.q1_network.parameters(), lr=Q_LR)
        self.q2_optimizer = torch.optim.SGD(self.q2_network.parameters(), lr=Q_LR)

In [19]:
replay_buffer = deque(maxlen=30_000)

In [20]:
env.action_space.n

4

In [21]:
env.observation_space.shape

(8,)

In [22]:
oss = env.observation_space.shape
if len(oss) != 1:
    raise RuntimeError(f'Unknown observation_space.shape: {oss}')
os_len = oss[0]
policy = Policy(os_len, env.action_space.n)

In [23]:
s = env.reset()
s

array([-0.00290499,  1.4156194 , -0.29424563,  0.20886293,  0.00337281,
        0.06665098,  0.        ,  0.        ], dtype=float32)

In [24]:
s = torch.tensor(s).reshape((1, -1))
s

tensor([[-0.0029,  1.4156, -0.2942,  0.2089,  0.0034,  0.0667,  0.0000,  0.0000]])

In [25]:
policy_output = policy.policy_network.forward(s)
policy_output

tensor([[0.2577, 0.2270, 0.2719, 0.2434]], grad_fn=<SoftmaxBackward0>)

In [26]:
action_weights = policy_output.reshape((-1,)).tolist()

In [27]:
list(range(4))

[0, 1, 2, 3]

In [28]:
action = random.choices(range(len(action_weights)), weights=action_weights)[0]
action

1

In [29]:
def action(policy, s):
    tensor_s = torch.tensor(s).reshape((1, -1))
    action_weights = policy.policy_network.forward(tensor_s).reshape((-1,)).tolist()
    action = random.choices(range(len(action_weights)), weights=action_weights)[0]
    return action

def step(initial_s, a, r, next_s, t, failed, limit):
    replay_buffer.append(SARS(initial_s, a, r, next_s, t, failed, limit))
    if RENDER:
        env.render()

In [30]:
run_episode(action, step, env, policy)

-79.45620529341322

In [31]:
len(replay_buffer)

69

In [32]:
replay_buffer

deque([SARS(state=array([-0.00662155,  1.3991531 , -0.67071474, -0.52300346,  0.0076796 ,
               0.15192688,  0.        ,  0.        ], dtype=float32), action=3, reward=-0.2625017425813201, next_state=array([-0.01317463,  1.3868008 , -0.66114223, -0.5490281 ,  0.01344991,
               0.11541824,  0.        ,  0.        ], dtype=float32), t=1, failed=False, limit=False),
       SARS(state=array([-0.01317463,  1.3868008 , -0.66114223, -0.5490281 ,  0.01344991,
               0.11541824,  0.        ,  0.        ], dtype=float32), action=0, reward=-1.021086704306839, next_state=array([-0.01972799,  1.3738481 , -0.6611601 , -0.5757368 ,  0.01921695,
               0.11535144,  0.        ,  0.        ], dtype=float32), t=2, failed=False, limit=False),
       SARS(state=array([-0.01972799,  1.3738481 , -0.6611601 , -0.5757368 ,  0.01921695,
               0.11535144,  0.        ,  0.        ], dtype=float32), action=1, reward=-1.734157630213075, next_state=array([-0.02634163,  1.36

# Polyak Averaging

In [33]:
test_parameter_1 = next(policy.policy_network.named_parameters())[1]
test_parameter_1

Parameter containing:
tensor([[-0.3277,  0.2090,  0.3297,  ...,  0.2051, -0.2188,  0.2909],
        [ 0.2528, -0.3106, -0.3053,  ..., -0.2703,  0.2191,  0.3401],
        [ 0.0430, -0.0486,  0.1010,  ...,  0.1192, -0.3395,  0.0166],
        ...,
        [ 0.0794,  0.2144,  0.1901,  ...,  0.0754, -0.0803,  0.3365],
        [ 0.1511,  0.1732, -0.2590,  ...,  0.2546, -0.0539, -0.2189],
        [-0.3249,  0.1417,  0.0578,  ..., -0.2527, -0.0187,  0.0227]],
       requires_grad=True)

In [34]:
test_parameter_2 = test_parameter_1 * 0 + 0.0128
test_parameter_2

tensor([[0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        ...,
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128]],
       grad_fn=<AddBackward0>)

In [35]:
test_parameter_1 * 0.9 + test_parameter_2 * 0.1

tensor([[-0.2936,  0.1894,  0.2980,  ...,  0.1859, -0.1956,  0.2631],
        [ 0.2288, -0.2783, -0.2735,  ..., -0.2420,  0.1985,  0.3073],
        [ 0.0400, -0.0424,  0.0922,  ...,  0.1086, -0.3043,  0.0162],
        ...,
        [ 0.0727,  0.1942,  0.1724,  ...,  0.0691, -0.0710,  0.3041],
        [ 0.1373,  0.1571, -0.2318,  ...,  0.2304, -0.0472, -0.1957],
        [-0.2911,  0.1288,  0.0533,  ..., -0.2262, -0.0156,  0.0217]],
       grad_fn=<AddBackward0>)

In [36]:
def polyak_update(network_to_update, target_network, tau=0.001):
    with torch.no_grad():
        for to_update, target in zip(network_to_update.parameters(), target_network.parameters()):
            to_update *= 1-tau
            to_update += target * tau

In [37]:
test_network_1 = QNetwork(5, 3)
test_network_2 = QNetwork(5, 3)
display(list(test_network_1.parameters())[0])
display(list(test_network_2.parameters())[0])
polyak_update(test_network_2, test_network_1, 0.1)
display(list(test_network_1.parameters())[0])
display(list(test_network_2.parameters())[0])

Parameter containing:
tensor([[ 3.1797e-01, -7.9571e-02,  2.9746e-01,  2.7632e-04,  2.9593e-01],
        [ 5.1364e-02, -1.1621e-01,  2.5640e-01, -2.1581e-01, -1.0060e-02],
        [-1.1234e-01, -6.9210e-02,  3.7045e-01,  4.1787e-01, -3.0456e-02],
        ...,
        [ 3.3643e-01,  2.6751e-01, -4.3720e-01, -8.4023e-02,  3.1005e-01],
        [-2.2432e-02,  4.0653e-01, -2.9094e-01, -1.0232e-01,  2.1357e-01],
        [-1.3168e-01,  4.3242e-01,  2.6597e-01,  9.5947e-02, -1.5247e-01]],
       requires_grad=True)

Parameter containing:
tensor([[-0.2534, -0.1465, -0.2794,  0.2493,  0.3167],
        [ 0.4129, -0.2002,  0.4287, -0.2244,  0.2414],
        [ 0.3810,  0.3812, -0.1373,  0.3810, -0.4304],
        ...,
        [ 0.4128, -0.2616,  0.1254,  0.2314,  0.3670],
        [ 0.2717, -0.0840,  0.3860, -0.2561,  0.2265],
        [-0.2464, -0.0903, -0.3841,  0.2254, -0.3334]], requires_grad=True)

Parameter containing:
tensor([[ 3.1797e-01, -7.9571e-02,  2.9746e-01,  2.7632e-04,  2.9593e-01],
        [ 5.1364e-02, -1.1621e-01,  2.5640e-01, -2.1581e-01, -1.0060e-02],
        [-1.1234e-01, -6.9210e-02,  3.7045e-01,  4.1787e-01, -3.0456e-02],
        ...,
        [ 3.3643e-01,  2.6751e-01, -4.3720e-01, -8.4023e-02,  3.1005e-01],
        [-2.2432e-02,  4.0653e-01, -2.9094e-01, -1.0232e-01,  2.1357e-01],
        [-1.3168e-01,  4.3242e-01,  2.6597e-01,  9.5947e-02, -1.5247e-01]],
       requires_grad=True)

Parameter containing:
tensor([[-0.1963, -0.1398, -0.2217,  0.2244,  0.3147],
        [ 0.3768, -0.1918,  0.4115, -0.2235,  0.2162],
        [ 0.3317,  0.3361, -0.0865,  0.3847, -0.3904],
        ...,
        [ 0.4051, -0.2087,  0.0691,  0.1999,  0.3613],
        [ 0.2423, -0.0349,  0.3183, -0.2407,  0.2252],
        [-0.2349, -0.0380, -0.3191,  0.2125, -0.3153]], requires_grad=True)

# Log Experiments

In [38]:
for p in [0.99, 0.9, 0.8, 0.6, 0.5]:
    logs = -np.log([p, 1-p])
    display(p, logs, sum(logs))

0.99

array([0.01005034, 4.60517019])

4.615220521841592

0.9

array([0.10536052, 2.30258509])

2.4079456086518722

0.8

array([0.22314355, 1.60943791])

1.8325814637483102

0.6

array([0.51082562, 0.91629073])

1.4271163556401456

0.5

array([0.69314718, 0.69314718])

1.3862943611198906

In [39]:
random.sample(replay_buffer, k=4)

[SARS(state=array([-0.43443236,  0.08218633, -0.50153446, -1.1220052 , -0.39559972,
         0.14371787,  1.        ,  0.        ], dtype=float32), action=0, reward=-100, next_state=array([-0.43789187,  0.05987457, -0.3710753 , -0.6972187 , -0.30030096,
         4.261471  ,  1.        ,  0.        ], dtype=float32), t=69, failed=True, limit=False),
 SARS(state=array([-0.2333026 ,  0.834262  , -0.6448106 , -0.8520876 , -0.10108083,
        -0.11448674,  0.        ,  0.        ], dtype=float32), action=0, reward=-0.9908267252213534, next_state=array([-0.23980837,  0.8144989 , -0.64481103, -0.87875706, -0.10680516,
        -0.11448647,  0.        ,  0.        ], dtype=float32), t=35, failed=False, limit=False),
 SARS(state=array([-0.14758578,  1.0705919 , -0.65535915, -0.8570098 ,  0.00884818,
        -0.16253468,  0.        ,  0.        ], dtype=float32), action=2, reward=2.7749922970558147, next_state=array([-1.5422793e-01,  1.0514498e+00, -6.5598470e-01, -8.5073107e-01,
         7.1310

# Training

![Psudocode](sac_psudocode.png)

Source: https://spinningup.openai.com/en/latest/algorithms/sac.html#pseudocode

In [40]:
def q_min(q1, q2, states):
    
    def f(q):
        state_values = q.forward(states).detach()
        return state_values
        
    return torch.minimum(*map(f, (q1, q2)))

In [41]:
    stats = {}
    # Step 11
    training_batch = random.sample(replay_buffer, k=min(len(replay_buffer), 100))
    # Prep
    states = torch.tensor(np.array([sars.state for sars in training_batch]), requires_grad=False)
    actions = torch.tensor(np.array([sars.action for sars in training_batch]), requires_grad=False)
    actions_hot = nn.functional.one_hot(actions)
    rewards = torch.tensor(np.array([sars.reward for sars in training_batch]), requires_grad=False)
    next_states = torch.tensor(np.array([sars.next_state for sars in training_batch]), requires_grad=False)
    fails = torch.tensor(np.array([sars.failed for sars in training_batch]), dtype=int, requires_grad=False)
    # Step 12
    next_action_probs = policy.policy_network.forward(next_states).detach()
    assert not next_action_probs.requires_grad
    next_states_q_min = q_min(policy.q1_target_network, policy.q2_target_network, next_states)
    assert not next_states_q_min.requires_grad
    next_actions_q_min = torch.sum(next_states_q_min * next_action_probs, 1)
    assert not next_actions_q_min.requires_grad
    next_actions_entropy = torch.sum(next_action_probs * torch.log(next_action_probs), 1)
    assert not next_actions_entropy.requires_grad
    y = rewards + GAMMA * (1-fails) * (next_actions_q_min - ALPHA * next_actions_entropy)
    assert not y.requires_grad
    # Step 13
    for qi, q, opt in ((1, policy.q1_network, policy.q1_optimizer),
                       (2, policy.q2_network, policy.q2_optimizer)):
        assert not states.requires_grad
        assert not actions_hot.requires_grad
        q_state_action = torch.sum(q.forward(states) * actions_hot, 1)
        assert q_state_action.requires_grad
        q_loss = torch.mean((q_state_action - y)**2)
        stats[f'train/q_loss_{qi}'] = q_loss
        assert q_loss.requires_grad
        opt.zero_grad()
        q_loss.backward()
        opt.step()
    # Step 14
    action_probs = policy.policy_network.forward(states)
    assert action_probs.requires_grad
    states_q_min = q_min(policy.q1_network, policy.q2_network, states)
    assert not states_q_min.requires_grad
    actions_q_min = torch.sum(states_q_min * action_probs, 1)
    assert actions_q_min.requires_grad
    actions_entropy = torch.sum(action_probs * torch.log(action_probs), 1)
    assert actions_entropy.requires_grad
    policy_loss = -1 * torch.mean(actions_q_min - ALPHA * actions_entropy)
    stats['train/policy_loss'] = policy_loss
    policy.policy_optimizer.zero_grad()
    policy_loss.backward()
    policy.policy_optimizer.step()
    # Step 15
    polyak_update(policy.q1_target_network, policy.q1_network, tau=TAU)
    polyak_update(policy.q2_target_network, policy.q2_network, tau=TAU)
    stats

{'train/q_loss_1': tensor(150.5596, dtype=torch.float64, grad_fn=<MeanBackward0>),
 'train/q_loss_2': tensor(150.1703, dtype=torch.float64, grad_fn=<MeanBackward0>),
 'train/policy_loss': tensor(-0.3494, grad_fn=<MulBackward0>)}

In [42]:
def train(policy, replay_buffer):
    stats = {}
    # Step 11
    training_batch = random.sample(replay_buffer, k=min(len(replay_buffer), 100))
    # Prep
    states = torch.tensor(np.array([sars.state for sars in training_batch]), requires_grad=False)
    actions = torch.tensor(np.array([sars.action for sars in training_batch]), requires_grad=False)
    actions_hot = nn.functional.one_hot(actions)
    rewards = torch.tensor(np.array([sars.reward for sars in training_batch]), requires_grad=False)
    next_states = torch.tensor(np.array([sars.next_state for sars in training_batch]), requires_grad=False)
    fails = torch.tensor(np.array([sars.failed for sars in training_batch]), dtype=int, requires_grad=False)
    # Step 12
    next_action_probs = policy.policy_network.forward(next_states).detach()
    assert not next_action_probs.requires_grad
    next_states_q_min = q_min(policy.q1_target_network, policy.q2_target_network, next_states)
    assert not next_states_q_min.requires_grad
    next_actions_q_min = torch.sum(next_states_q_min * next_action_probs, 1)
    assert not next_actions_q_min.requires_grad
    next_actions_entropy = torch.sum(next_action_probs * torch.log(next_action_probs), 1)
    assert not next_actions_entropy.requires_grad
    y = rewards + GAMMA * (1-fails) * (next_actions_q_min - ALPHA * next_actions_entropy)
    assert not y.requires_grad
    # Step 13
    for qi, q, opt in ((1, policy.q1_network, policy.q1_optimizer),
                       (2, policy.q2_network, policy.q2_optimizer)):
        assert not states.requires_grad
        assert not actions_hot.requires_grad
        try:
            # Been getting some unexpected errors on this line, so logging exception details.
            q_state_action = torch.sum(q.forward(states) * actions_hot, 1)
        except RuntimeError:
            display(states.shape)
            display(actions_hot.shape)
        assert q_state_action.requires_grad
        q_loss = torch.mean((q_state_action - y)**2)
        stats[f'train/q_loss_{qi}'] = q_loss
        assert q_loss.requires_grad
        opt.zero_grad()
        q_loss.backward()
        opt.step()
    # Step 14
    action_probs = policy.policy_network.forward(states)
    assert action_probs.requires_grad
    states_q_min = q_min(policy.q1_network, policy.q2_network, states)
    assert not states_q_min.requires_grad
    actions_q_min = torch.sum(states_q_min * action_probs, 1)
    assert actions_q_min.requires_grad
    actions_entropy = torch.sum(action_probs * torch.log(action_probs), 1)
    assert actions_entropy.requires_grad
    policy_loss = -1 * torch.mean(actions_q_min - ALPHA * actions_entropy)
    stats['train/policy_loss'] = policy_loss
    policy.policy_optimizer.zero_grad()
    policy_loss.backward()
    policy.policy_optimizer.step()
    # Step 15
    polyak_update(policy.q1_target_network, policy.q1_network, tau=TAU)
    polyak_update(policy.q2_target_network, policy.q2_network, tau=TAU)
    return stats

In [None]:
tb_writer = SummaryWriter()

oss = env.observation_space.shape
if len(oss) != 1:
    raise RuntimeError(f'Unknown observation_space.shape: {oss}')
os_len = oss[0]
policy = Policy(os_len, env.action_space.n)

replay_buffer = deque(maxlen=30_000)

def action(policy, s):
    tensor_s = torch.tensor(s).reshape((1, -1))
    action_weights = policy.policy_network.forward(tensor_s).reshape((-1,)).tolist()
    action = random.choices(range(len(action_weights)), weights=action_weights)[0]
    return action

def step(initial_s, a, r, next_s, t, failed, limit):
    replay_buffer.append(SARS(initial_s, a, r, next_s, t, failed, limit))
    if RENDER:
        env.render()

for episode in tqdm.tqdm(range(1, 1000+1)):
    episode_reward = run_episode(action, step, env, policy, fail_at_limit=True)
    tb_writer.add_scalar('main/episode_reward', episode_reward, episode)
    tb_writer.add_scalar('main/replay_buffer_length', len(replay_buffer), episode)
    policy.reset_optimizers()
    for training_iteration in range(1, 100+1):
        stats = train(policy, replay_buffer)
        for stat, value in stats.items():
            tb_writer.add_scalar(stat, value, episode)

 40%|██████████████████████████████████████████████                                                                     | 400/1000 [06:53<14:32,  1.45s/it]