In [1]:
import gym
import torch
from torch import nn
from collections import namedtuple, deque
import itertools
from copy import deepcopy
import random
import numpy as np

  import distutils.spawn


In [2]:
RENDER = False

In [3]:
ALPHA = 0.8  # How much we value entropy / exploration, increasing this will increase exploration.
GAMMA = 0.99  # How much we value future rewards.
TAU = 0.001  # How much q_target is updated when polyak averaging (step 15).
LR = 0.001  # Optimizer learning rate.

In [4]:
env = gym.make("CartPole-v1")

In [5]:
SARS = namedtuple('SARS', 'state, action, reward, next_state, t, failed, limit')

In [6]:
softmax = nn.Softmax(dim=0)
input = torch.tensor([1, 2, 3], dtype=float)
display(input)
output = softmax(input)
display(output)
sum(output)

tensor([1., 2., 3.], dtype=torch.float64)

tensor([0.0900, 0.2447, 0.6652], dtype=torch.float64)

tensor(1.0000, dtype=torch.float64)

In [7]:
softmax = nn.Softmax(dim=1)
input = torch.tensor([[1, 2, 3], [1, 2, 3], [3, 3, 3]], dtype=float)
display(input)
output = softmax(input)
display(output)
sum(output)

tensor([[1., 2., 3.],
        [1., 2., 3.],
        [3., 3., 3.]], dtype=torch.float64)

tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652],
        [0.3333, 0.3333, 0.3333]], dtype=torch.float64)

tensor([0.5134, 0.8228, 1.6638], dtype=torch.float64)

In [8]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 600),
            nn.ReLU(),
            nn.Linear(600, 200),
            nn.ReLU(),
            nn.Linear(200, output_size)
        )

    def forward(self, x):
        nn_out = self.linear_relu_stack(x)
        return nn.Softmax(dim=1)(nn_out)

    def __call__(self, x):
        raise RuntimeError("Use forward")

In [9]:
policy_network = PolicyNetwork(4, 2)
policy_network

PolicyNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=4, out_features=600, bias=True)
    (1): ReLU()
    (2): Linear(in_features=600, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=2, bias=True)
  )
)

In [10]:
torch.cuda.is_available()

True

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [12]:
mock_states = torch.rand(5, 4)
mock_states

tensor([[0.2276, 0.1398, 0.8639, 0.4408],
        [0.7641, 0.8586, 0.2373, 0.2574],
        [0.7042, 0.0842, 0.7812, 0.9577],
        [0.6487, 0.5499, 0.1303, 0.1788],
        [0.8241, 0.4027, 0.3977, 0.8717]])

In [13]:
policy_network.forward(mock_states)

tensor([[0.5134, 0.4866],
        [0.4818, 0.5182],
        [0.5123, 0.4877],
        [0.4909, 0.5091],
        [0.4911, 0.5089]], grad_fn=<SoftmaxBackward0>)

In [14]:
class QNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 600),
            nn.ReLU(),
            nn.Linear(600, 200),
            nn.ReLU(),
            nn.Linear(200, output_size)
        )

    def forward(self, x):
        nn_out = self.linear_relu_stack(x)
        return nn_out
    
    def __call__(self, x):
        raise RuntimeError("Use forward")

In [15]:
q_network = QNetwork(4, 2)
q_network

QNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=4, out_features=600, bias=True)
    (1): ReLU()
    (2): Linear(in_features=600, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=2, bias=True)
  )
)

In [16]:
q_network.forward(mock_states)

tensor([[ 0.0154, -0.1036],
        [-0.0239, -0.0566],
        [ 0.0092, -0.1231],
        [-0.0412, -0.0572],
        [-0.0089, -0.1071]], grad_fn=<AddmmBackward0>)

In [17]:
def run_episode(action_f, step_f, env, policy):
    episode_reward = 0
    s = env.reset()
    for t in itertools.count(start=0):
        a = action_f(policy, s)
        next_state, reward, failed, info = env.step(a)
        episode_reward += reward
        assert t <= env._max_episode_steps
        limit = t == env._max_episode_steps
        if limit:
            failed = False
        assert not (limit and failed)
        step_f(s, a, reward, next_state, t, failed, limit)
        if failed or limit:
            break
        s = next_state

In [18]:
class Policy:
    def __init__(self, env_state_size, env_action_space_size):
        self.policy_network = PolicyNetwork(env_state_size, env_action_space_size)
        self.q1_network = QNetwork(env_state_size, env_action_space_size)
        self.q2_network = QNetwork(env_state_size, env_action_space_size)
        self.q1_target_network = deepcopy(self.q1_network)
        self.q2_target_network = deepcopy(self.q2_network)
        self.policy_optimizer = torch.optim.Adam(self.policy_network.parameters(), lr=LR)
        self.q1_optimizer = torch.optim.Adam(self.q1_network.parameters(), lr=LR)
        self.q2_optimizer = torch.optim.Adam(self.q2_network.parameters(), lr=LR)

In [19]:
replay_buffer = deque(maxlen=10_000)

In [20]:
env.action_space.n

2

In [21]:
env.observation_space.shape

(4,)

In [22]:
oss = env.observation_space.shape
if len(oss) != 1:
    raise RuntimeError(f'Unknown observation_space.shape: {oss}')
os_len = oss[0]
policy = Policy(os_len, env.action_space.n)

In [23]:
s = env.reset()
s

array([ 0.02057485, -0.01812934,  0.02035434, -0.02117978], dtype=float32)

In [24]:
s = torch.tensor(s).reshape((1, -1))
s

tensor([[ 0.0206, -0.0181,  0.0204, -0.0212]])

In [25]:
policy_output = policy.policy_network.forward(s)
policy_output

tensor([[0.5068, 0.4932]], grad_fn=<SoftmaxBackward0>)

In [26]:
action_weights = policy_output.reshape((-1,)).tolist()

In [27]:
list(range(4))

[0, 1, 2, 3]

In [28]:
action = random.choices(range(len(action_weights)), weights=action_weights)[0]
action

0

In [29]:
def action(policy, s):
    tensor_s = torch.tensor(s).reshape((1, -1))
    action_weights = policy.policy_network.forward(tensor_s).reshape((-1,)).tolist()
    action = random.choices(range(len(action_weights)), weights=action_weights)[0]
    return action

def step(initial_s, a, r, next_s, t, failed, limit):
    replay_buffer.append(SARS(initial_s, a, r, next_s, t, failed, limit))
    if RENDER:
        env.render()

In [30]:
run_episode(action, step, env, policy)

In [31]:
len(replay_buffer)

38

In [32]:
replay_buffer

deque([SARS(state=array([-0.03245957, -0.02608964,  0.01833498,  0.0409357 ], dtype=float32), action=0, reward=1.0, next_state=array([-0.03298136, -0.22146964,  0.0191537 ,  0.33934665], dtype=float32), t=0, failed=False, limit=False),
       SARS(state=array([-0.03298136, -0.22146964,  0.0191537 ,  0.33934665], dtype=float32), action=1, reward=1.0, next_state=array([-0.03741075, -0.0266254 ,  0.02594063,  0.05276472], dtype=float32), t=1, failed=False, limit=False),
       SARS(state=array([-0.03741075, -0.0266254 ,  0.02594063,  0.05276472], dtype=float32), action=0, reward=1.0, next_state=array([-0.03794326, -0.22210951,  0.02699593,  0.35351792], dtype=float32), t=2, failed=False, limit=False),
       SARS(state=array([-0.03794326, -0.22210951,  0.02699593,  0.35351792], dtype=float32), action=1, reward=1.0, next_state=array([-0.04238545, -0.02738163,  0.03406629,  0.06946837], dtype=float32), t=3, failed=False, limit=False),
       SARS(state=array([-0.04238545, -0.02738163,  0.03

# Polyak Averaging

In [33]:
test_parameter_1 = next(policy.policy_network.named_parameters())[1]
test_parameter_1

Parameter containing:
tensor([[ 0.0737, -0.4976,  0.2914,  0.3621],
        [ 0.3490,  0.1600,  0.1528, -0.0690],
        [ 0.2899,  0.4529,  0.1889, -0.0258],
        ...,
        [-0.1919, -0.1023, -0.3327,  0.2168],
        [ 0.2567,  0.0165, -0.4353, -0.2538],
        [ 0.3787,  0.0832,  0.1555,  0.0082]], requires_grad=True)

In [34]:
test_parameter_2 = test_parameter_1 * 0 + 0.0128
test_parameter_2

tensor([[0.0128, 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128, 0.0128],
        ...,
        [0.0128, 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128, 0.0128]], grad_fn=<AddBackward0>)

In [35]:
test_parameter_1 * 0.9 + test_parameter_2 * 0.1

tensor([[ 0.0676, -0.4466,  0.2635,  0.3272],
        [ 0.3154,  0.1453,  0.1388, -0.0608],
        [ 0.2622,  0.4089,  0.1713, -0.0219],
        ...,
        [-0.1715, -0.0908, -0.2981,  0.1964],
        [ 0.2323,  0.0161, -0.3905, -0.2271],
        [ 0.3421,  0.0762,  0.1413,  0.0086]], grad_fn=<AddBackward0>)

In [36]:
def polyak_update(network_to_update, target_network, tau=0.001):
    with torch.no_grad():
        for to_update, target in zip(network_to_update.parameters(), target_network.parameters()):
            to_update *= 1-tau
            to_update += target * tau

In [37]:
test_network_1 = QNetwork(5, 3)
test_network_2 = QNetwork(5, 3)
display(list(test_network_1.parameters())[0])
display(list(test_network_2.parameters())[0])
polyak_update(test_network_2, test_network_1, 0.1)
display(list(test_network_1.parameters())[0])
display(list(test_network_2.parameters())[0])

Parameter containing:
tensor([[ 0.0298,  0.0279, -0.0442,  0.0963, -0.0846],
        [ 0.1909,  0.2909, -0.3064, -0.3638, -0.0099],
        [ 0.1050,  0.2826,  0.0005,  0.2853, -0.1424],
        ...,
        [-0.1232, -0.3184,  0.1565,  0.2264, -0.3666],
        [ 0.2386, -0.2276,  0.3120,  0.1630, -0.0011],
        [ 0.1266, -0.2548, -0.2124, -0.4007,  0.0891]], requires_grad=True)

Parameter containing:
tensor([[ 0.2674, -0.2624,  0.1772, -0.1483, -0.1299],
        [ 0.4193,  0.4128,  0.3540, -0.2280, -0.0863],
        [ 0.4243, -0.1306,  0.2992,  0.4241, -0.0326],
        ...,
        [ 0.0617, -0.3670,  0.1081, -0.1708, -0.4075],
        [-0.3419, -0.3378,  0.1305, -0.3625, -0.4252],
        [-0.2501, -0.1479, -0.1584, -0.2678, -0.3493]], requires_grad=True)

Parameter containing:
tensor([[ 0.0298,  0.0279, -0.0442,  0.0963, -0.0846],
        [ 0.1909,  0.2909, -0.3064, -0.3638, -0.0099],
        [ 0.1050,  0.2826,  0.0005,  0.2853, -0.1424],
        ...,
        [-0.1232, -0.3184,  0.1565,  0.2264, -0.3666],
        [ 0.2386, -0.2276,  0.3120,  0.1630, -0.0011],
        [ 0.1266, -0.2548, -0.2124, -0.4007,  0.0891]], requires_grad=True)

Parameter containing:
tensor([[ 0.2436, -0.2334,  0.1550, -0.1238, -0.1253],
        [ 0.3965,  0.4006,  0.2879, -0.2416, -0.0786],
        [ 0.3924, -0.0893,  0.2693,  0.4102, -0.0436],
        ...,
        [ 0.0432, -0.3621,  0.1129, -0.1311, -0.4034],
        [-0.2838, -0.3268,  0.1487, -0.3100, -0.3828],
        [-0.2125, -0.1586, -0.1638, -0.2811, -0.3055]], requires_grad=True)

# Log Experiments

In [38]:
for p in [0.99, 0.9, 0.8, 0.6, 0.5]:
    logs = -np.log([p, 1-p])
    display(p, logs, sum(logs))

0.99

array([0.01005034, 4.60517019])

4.615220521841592

0.9

array([0.10536052, 2.30258509])

2.4079456086518722

0.8

array([0.22314355, 1.60943791])

1.8325814637483102

0.6

array([0.51082562, 0.91629073])

1.4271163556401456

0.5

array([0.69314718, 0.69314718])

1.3862943611198906

In [39]:
random.sample(replay_buffer, k=4)

[SARS(state=array([ 0.03807039,  0.35416862, -0.02907363, -0.32429612], dtype=float32), action=1, reward=1.0, next_state=array([ 0.04515376,  0.5496922 , -0.03555956, -0.62600404], dtype=float32), t=24, failed=False, limit=False),
 SARS(state=array([ 0.13356096,  0.55831546, -0.14611018, -0.8238011 ], dtype=float32), action=0, reward=1.0, next_state=array([ 0.14472726,  0.36546195, -0.1625862 , -0.58040684], dtype=float32), t=33, failed=False, limit=False),
 SARS(state=array([-0.04522029, -0.2249754 ,  0.04314569,  0.41696075], dtype=float32), action=1, reward=1.0, next_state=array([-0.0497198 , -0.03049062,  0.0514849 ,  0.13818595], dtype=float32), t=9, failed=False, limit=False),
 SARS(state=array([-0.00117357,  0.54845834,  0.0075393 , -0.5985675 ], dtype=float32), action=0, reward=1.0, next_state=array([ 0.0097956 ,  0.3532317 , -0.00443205, -0.30351934], dtype=float32), t=19, failed=False, limit=False)]

# Training

![Psudocode](sac_psudocode.png)

Source: https://spinningup.openai.com/en/latest/algorithms/sac.html#pseudocode

In [40]:
def q_min(q1, q2, states, actions_hot):
    
    def f(q):
        state_values = q.forward(states).detach()
        chosen_action_values = torch.sum(state_values * actions_hot, 1)
        return chosen_action_values
        
    return torch.minimum(*map(f, (q1, q2)))

In [41]:
# Step 11
training_batch = random.sample(replay_buffer, k=min(len(replay_buffer), 100))
# Step 12
states = torch.tensor(np.array([sars.state for sars in training_batch]), requires_grad=False)
actions = torch.tensor(np.array([sars.action for sars in training_batch]), requires_grad=False)
actions_hot = nn.functional.one_hot(actions)
rewards = torch.tensor(np.array([sars.reward for sars in training_batch]), requires_grad=False)
next_states = torch.tensor(np.array([sars.next_state for sars in training_batch]), requires_grad=False)
fails = torch.tensor(np.array([sars.failed for sars in training_batch]), dtype=int, requires_grad=False)
next_action_probs = policy.policy_network.forward(next_states).detach()

def to_action_probs(probs):
    action_index = random.choices(range(len(probs)), weights=probs)[0]
    a = action_index
    p = probs[action_index]
    return [a, p]

next_actions_with_probs = np.apply_along_axis(to_action_probs, 1, next_action_probs.numpy())
# The next action chosen by the policy network
next_actions = torch.tensor(next_actions_with_probs[:, 0], dtype=int, requires_grad=False)
next_actions_hot = nn.functional.one_hot(next_actions)
# The probability of the next action chosen by the policy network
next_probs = torch.tensor(next_actions_with_probs[:, 1], requires_grad=False)
sampled_next_action_q_min = q_min(
    policy.q1_target_network,
    policy.q2_target_network,
    next_states,
    next_actions_hot)
sampled_next_action_prob = torch.sum(next_action_probs * next_actions_hot, 1)
sampled_next_action_log_prob = torch.log(sampled_next_action_prob)
y = rewards + GAMMA * (1-fails) * (sampled_next_action_q_min - ALPHA * sampled_next_action_log_prob)
# Step 13
for tensor in (states, actions_hot, y):
    assert not tensor.requires_grad
for q, opt in ((policy.q1_network, policy.q1_optimizer),
               (policy.q2_network, policy.q2_optimizer)):
    opt.zero_grad()
    q_state_action = torch.sum(q.forward(states) * actions_hot, 1)
    assert q_state_action.requires_grad
    q_loss = torch.mean((q_state_action - y)**2)
    q_loss.backward()
    opt.step()
# Step 14
policy.policy_optimizer.zero_grad()
action_probs = policy.policy_network.forward(states)
assert action_probs.requires_grad
sampled_action_probs = torch.tensor(np.apply_along_axis(to_action_probs, 1, action_probs.detach().numpy()))
assert not sampled_action_probs.requires_grad
sampled_actions = sampled_action_probs[:, 0].long()
assert not sampled_actions.requires_grad
sampled_actions_hot = nn.functional.one_hot(sampled_actions)
assert not sampled_actions_hot.requires_grad
sampled_probs = torch.sum(action_probs * sampled_actions_hot, 1)
assert sampled_probs.requires_grad
sampled_log_probs = torch.log(sampled_probs)
assert sampled_log_probs.requires_grad
sampled_action_q_min = q_min(policy.q1_network, policy.q2_network, states, sampled_actions.reshape((-1, 1)))
assert not sampled_action_q_min.requires_grad
policy_loss = torch.mean(sampled_action_q_min - ALPHA * sampled_log_probs)
policy_loss.backward()
policy.policy_optimizer.step()
# Step 15
polyak_update(policy.q1_target_network, policy.q1_network, tau=TAU)
polyak_update(policy.q2_target_network, policy.q2_network, tau=TAU)

In [42]:
def train(policy, replay_buffer):
    # Step 11
    training_batch = random.sample(replay_buffer, k=min(len(replay_buffer), 100))
    # Step 12
    states = torch.tensor(np.array([sars.state for sars in training_batch]), requires_grad=False)
    actions = torch.tensor(np.array([sars.action for sars in training_batch]), requires_grad=False)
    actions_hot = nn.functional.one_hot(actions)
    rewards = torch.tensor(np.array([sars.reward for sars in training_batch]), requires_grad=False)
    next_states = torch.tensor(np.array([sars.next_state for sars in training_batch]), requires_grad=False)
    fails = torch.tensor(np.array([sars.failed for sars in training_batch]), dtype=int, requires_grad=False)
    next_action_probs = policy.policy_network.forward(next_states).detach()

    def to_action_probs(probs):
        action_index = random.choices(range(len(probs)), weights=probs)[0]
        a = action_index
        p = probs[action_index]
        return [a, p]

    next_actions_with_probs = np.apply_along_axis(to_action_probs, 1, next_action_probs.numpy())
    # The next action chosen by the policy network
    next_actions = torch.tensor(next_actions_with_probs[:, 0], dtype=int, requires_grad=False)
    next_actions_hot = nn.functional.one_hot(next_actions)
    # The probability of the next action chosen by the policy network
    next_probs = torch.tensor(next_actions_with_probs[:, 1], requires_grad=False)
    sampled_next_action_q_min = q_min(
        policy.q1_target_network,
        policy.q2_target_network,
        next_states,
        next_actions_hot)
    sampled_next_action_prob = torch.sum(next_action_probs * next_actions_hot, 1)
    sampled_next_action_log_prob = torch.log(sampled_next_action_prob)
    y = rewards + GAMMA * (1-fails) * (sampled_next_action_q_min - ALPHA * sampled_next_action_log_prob)
    # Step 13
    for tensor in (states, actions_hot, y):
        assert not tensor.requires_grad
    for q, opt in ((policy.q1_network, policy.q1_optimizer),
                   (policy.q2_network, policy.q2_optimizer)):
        opt.zero_grad()
        q_state_action = torch.sum(q.forward(states) * actions_hot, 1)
        assert q_state_action.requires_grad
        q_loss = torch.mean((q_state_action - y)**2)
        q_loss.backward()
        opt.step()
    # Step 14
    policy.policy_optimizer.zero_grad()
    action_probs = policy.policy_network.forward(states)
    assert action_probs.requires_grad
    sampled_action_probs = torch.tensor(np.apply_along_axis(to_action_probs, 1, action_probs.detach().numpy()))
    assert not sampled_action_probs.requires_grad
    sampled_actions = sampled_action_probs[:, 0].long()
    assert not sampled_actions.requires_grad
    sampled_actions_hot = nn.functional.one_hot(sampled_actions)
    assert not sampled_actions_hot.requires_grad
    sampled_probs = torch.sum(action_probs * sampled_actions_hot, 1)
    assert sampled_probs.requires_grad
    sampled_log_probs = torch.log(sampled_probs)
    assert sampled_log_probs.requires_grad
    sampled_action_q_min = q_min(policy.q1_network, policy.q2_network, states, sampled_actions.reshape((-1, 1)))
    assert not sampled_action_q_min.requires_grad
    policy_loss = torch.mean(sampled_action_q_min - ALPHA * sampled_log_probs)
    policy_loss.backward()
    policy.policy_optimizer.step()
    # Step 15
    polyak_update(policy.q1_target_network, policy.q1_network, tau=TAU)
    polyak_update(policy.q2_target_network, policy.q2_network, tau=TAU)