In [1]:
import gym
import torch
from torch import nn
from collections import namedtuple, deque
import itertools
from copy import deepcopy
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import tqdm

  import distutils.spawn


In [2]:
RENDER = False

In [3]:
ALPHA = 0.3  # How much we value entropy / exploration, increasing this will increase exploration.
GAMMA = 0.99  # How much we value future rewards.
TAU = 0.001  # How much q_target is updated when polyak averaging (step 15).
POLICY_LR = 0.001  # Policy learning rate.
Q_LR = 0.001  # Q learning rate.

In [4]:
env = gym.make("CartPole-v1")

In [5]:
SARS = namedtuple('SARS', 'state, action, reward, next_state, t, failed, limit')

In [6]:
softmax = nn.Softmax(dim=0)
input = torch.tensor([1, 2, 3], dtype=float)
display(input)
output = softmax(input)
display(output)
sum(output)

tensor([1., 2., 3.], dtype=torch.float64)

tensor([0.0900, 0.2447, 0.6652], dtype=torch.float64)

tensor(1.0000, dtype=torch.float64)

In [7]:
softmax = nn.Softmax(dim=1)
input = torch.tensor([[1, 2, 3], [1, 2, 3], [3, 3, 3]], dtype=float)
display(input)
output = softmax(input)
display(output)
sum(output)

tensor([[1., 2., 3.],
        [1., 2., 3.],
        [3., 3., 3.]], dtype=torch.float64)

tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652],
        [0.3333, 0.3333, 0.3333]], dtype=torch.float64)

tensor([0.5134, 0.8228, 1.6638], dtype=torch.float64)

In [8]:
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 600),
            nn.ReLU(),
            nn.Linear(600, 200),
            nn.ReLU(),
            nn.Linear(200, output_size)
        )

    def forward(self, x):
        nn_out = self.linear_relu_stack(x)
        return nn.Softmax(dim=1)(nn_out)

    def __call__(self, x):
        raise RuntimeError("Use forward")

In [9]:
policy_network = PolicyNetwork(4, 2)
policy_network

PolicyNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=4, out_features=600, bias=True)
    (1): ReLU()
    (2): Linear(in_features=600, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=2, bias=True)
  )
)

In [10]:
torch.cuda.is_available()

True

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [12]:
mock_states = torch.rand(5, 4)
mock_states

tensor([[0.6190, 0.2480, 0.1420, 0.8914],
        [0.1900, 0.7467, 0.5392, 0.6186],
        [0.5108, 0.0135, 0.0980, 0.9990],
        [0.8812, 0.4234, 0.1902, 0.7100],
        [0.4196, 0.7565, 0.3873, 0.6235]])

In [13]:
policy_network.forward(mock_states)

tensor([[0.4756, 0.5244],
        [0.4884, 0.5116],
        [0.4798, 0.5202],
        [0.4785, 0.5215],
        [0.4795, 0.5205]], grad_fn=<SoftmaxBackward0>)

In [14]:
class QNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 600),
            nn.ReLU(),
            nn.Linear(600, 200),
            nn.ReLU(),
            nn.Linear(200, output_size)
        )

    def forward(self, x):
        nn_out = self.linear_relu_stack(x)
        return nn_out
    
    def __call__(self, x):
        raise RuntimeError("Use forward")

In [15]:
q_network = QNetwork(4, 2)
q_network

QNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=4, out_features=600, bias=True)
    (1): ReLU()
    (2): Linear(in_features=600, out_features=200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=200, out_features=2, bias=True)
  )
)

In [16]:
q_network.forward(mock_states)

tensor([[ 0.0454, -0.1411],
        [-0.0014, -0.1465],
        [ 0.0607, -0.1517],
        [ 0.0396, -0.1380],
        [-0.0006, -0.1346]], grad_fn=<AddmmBackward0>)

In [17]:
def run_episode(action_f, step_f, env, policy):
    episode_reward = 0
    s = env.reset()
    for t in itertools.count(start=0):
        a = action_f(policy, s)
        next_state, reward, failed, info = env.step(a)
        episode_reward += reward
        assert t <= env._max_episode_steps
        limit = t == env._max_episode_steps
        if limit:
            failed = False
        assert not (limit and failed)
        step_f(s, a, reward, next_state, t, failed, limit)
        if failed or limit:
            break
        s = next_state
    return episode_reward

In [18]:
class Policy:
    def __init__(self, env_state_size, env_action_space_size):
        self.policy_network = PolicyNetwork(env_state_size, env_action_space_size)
        self.q1_network = QNetwork(env_state_size, env_action_space_size)
        self.q2_network = QNetwork(env_state_size, env_action_space_size)
        self.q1_target_network = deepcopy(self.q1_network)
        self.q2_target_network = deepcopy(self.q2_network)
        self.policy_optimizer = torch.optim.SGD(self.policy_network.parameters(), lr=POLICY_LR)
        self.q1_optimizer = torch.optim.SGD(self.q1_network.parameters(), lr=Q_LR)
        self.q2_optimizer = torch.optim.SGD(self.q2_network.parameters(), lr=Q_LR)

In [19]:
replay_buffer = deque(maxlen=10_000)

In [20]:
env.action_space.n

2

In [21]:
env.observation_space.shape

(4,)

In [22]:
oss = env.observation_space.shape
if len(oss) != 1:
    raise RuntimeError(f'Unknown observation_space.shape: {oss}')
os_len = oss[0]
policy = Policy(os_len, env.action_space.n)

In [23]:
s = env.reset()
s

array([-0.00700782, -0.03484046, -0.00119085,  0.02735358], dtype=float32)

In [24]:
s = torch.tensor(s).reshape((1, -1))
s

tensor([[-0.0070, -0.0348, -0.0012,  0.0274]])

In [25]:
policy_output = policy.policy_network.forward(s)
policy_output

tensor([[0.5146, 0.4854]], grad_fn=<SoftmaxBackward0>)

In [26]:
action_weights = policy_output.reshape((-1,)).tolist()

In [27]:
list(range(4))

[0, 1, 2, 3]

In [28]:
action = random.choices(range(len(action_weights)), weights=action_weights)[0]
action

0

In [29]:
def action(policy, s):
    tensor_s = torch.tensor(s).reshape((1, -1))
    action_weights = policy.policy_network.forward(tensor_s).reshape((-1,)).tolist()
    action = random.choices(range(len(action_weights)), weights=action_weights)[0]
    return action

def step(initial_s, a, r, next_s, t, failed, limit):
    replay_buffer.append(SARS(initial_s, a, r, next_s, t, failed, limit))
    if RENDER:
        env.render()

In [30]:
run_episode(action, step, env, policy)

9.0

In [31]:
len(replay_buffer)

9

In [32]:
replay_buffer

deque([SARS(state=array([ 0.00205541, -0.0131901 , -0.03513844, -0.00672308], dtype=float32), action=1, reward=1.0, next_state=array([ 0.00179161,  0.1824177 , -0.0352729 , -0.31028226], dtype=float32), t=0, failed=False, limit=False),
       SARS(state=array([ 0.00179161,  0.1824177 , -0.0352729 , -0.31028226], dtype=float32), action=1, reward=1.0, next_state=array([ 0.00543996,  0.37802398, -0.04147855, -0.61387724], dtype=float32), t=1, failed=False, limit=False),
       SARS(state=array([ 0.00543996,  0.37802398, -0.04147855, -0.61387724], dtype=float32), action=1, reward=1.0, next_state=array([ 0.01300044,  0.57370025, -0.0537561 , -0.9193304 ], dtype=float32), t=2, failed=False, limit=False),
       SARS(state=array([ 0.01300044,  0.57370025, -0.0537561 , -0.9193304 ], dtype=float32), action=1, reward=1.0, next_state=array([ 0.02447445,  0.769506  , -0.07214271, -1.2284114 ], dtype=float32), t=3, failed=False, limit=False),
       SARS(state=array([ 0.02447445,  0.769506  , -0.07

# Polyak Averaging

In [33]:
test_parameter_1 = next(policy.policy_network.named_parameters())[1]
test_parameter_1

Parameter containing:
tensor([[ 0.0357,  0.4357, -0.1708, -0.4519],
        [-0.1181, -0.0746, -0.2929, -0.0853],
        [-0.1354, -0.3029, -0.2852, -0.0292],
        ...,
        [-0.0623, -0.1166, -0.3547,  0.4216],
        [-0.3434,  0.3784,  0.1392, -0.0764],
        [ 0.3388, -0.0763,  0.0683, -0.2313]], requires_grad=True)

In [34]:
test_parameter_2 = test_parameter_1 * 0 + 0.0128
test_parameter_2

tensor([[0.0128, 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128, 0.0128],
        ...,
        [0.0128, 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128, 0.0128]], grad_fn=<AddBackward0>)

In [35]:
test_parameter_1 * 0.9 + test_parameter_2 * 0.1

tensor([[ 0.0334,  0.3934, -0.1524, -0.4054],
        [-0.1050, -0.0659, -0.2623, -0.0755],
        [-0.1206, -0.2713, -0.2554, -0.0250],
        ...,
        [-0.0548, -0.1037, -0.3179,  0.3808],
        [-0.3078,  0.3418,  0.1265, -0.0675],
        [ 0.3062, -0.0674,  0.0627, -0.2069]], grad_fn=<AddBackward0>)

In [36]:
def polyak_update(network_to_update, target_network, tau=0.001):
    with torch.no_grad():
        for to_update, target in zip(network_to_update.parameters(), target_network.parameters()):
            to_update *= 1-tau
            to_update += target * tau

In [37]:
test_network_1 = QNetwork(5, 3)
test_network_2 = QNetwork(5, 3)
display(list(test_network_1.parameters())[0])
display(list(test_network_2.parameters())[0])
polyak_update(test_network_2, test_network_1, 0.1)
display(list(test_network_1.parameters())[0])
display(list(test_network_2.parameters())[0])

Parameter containing:
tensor([[-0.1086, -0.3811,  0.4121,  0.4447, -0.0613],
        [-0.2336, -0.4341,  0.3314, -0.2447,  0.0519],
        [-0.2620,  0.0783,  0.3393, -0.1152,  0.1455],
        ...,
        [-0.3957, -0.2433,  0.4255, -0.0991, -0.3537],
        [ 0.0109, -0.2860,  0.1754, -0.2197, -0.1117],
        [-0.0800,  0.3955, -0.0851,  0.4172, -0.1765]], requires_grad=True)

Parameter containing:
tensor([[-0.3747, -0.3947,  0.4414, -0.0869,  0.3169],
        [-0.0810, -0.4107,  0.4362,  0.0486, -0.3011],
        [ 0.4085,  0.2738, -0.0793, -0.0953,  0.2122],
        ...,
        [-0.1386, -0.3274, -0.0685,  0.2992, -0.4003],
        [ 0.3086,  0.4255,  0.3919, -0.2712,  0.3654],
        [ 0.0938,  0.0681,  0.2434, -0.1826, -0.3493]], requires_grad=True)

Parameter containing:
tensor([[-0.1086, -0.3811,  0.4121,  0.4447, -0.0613],
        [-0.2336, -0.4341,  0.3314, -0.2447,  0.0519],
        [-0.2620,  0.0783,  0.3393, -0.1152,  0.1455],
        ...,
        [-0.3957, -0.2433,  0.4255, -0.0991, -0.3537],
        [ 0.0109, -0.2860,  0.1754, -0.2197, -0.1117],
        [-0.0800,  0.3955, -0.0851,  0.4172, -0.1765]], requires_grad=True)

Parameter containing:
tensor([[-0.3481, -0.3933,  0.4385, -0.0338,  0.2791],
        [-0.0962, -0.4130,  0.4257,  0.0193, -0.2658],
        [ 0.3414,  0.2543, -0.0374, -0.0973,  0.2055],
        ...,
        [-0.1643, -0.3190, -0.0191,  0.2594, -0.3957],
        [ 0.2788,  0.3544,  0.3702, -0.2660,  0.3177],
        [ 0.0764,  0.1009,  0.2106, -0.1226, -0.3321]], requires_grad=True)

# Log Experiments

In [38]:
for p in [0.99, 0.9, 0.8, 0.6, 0.5]:
    logs = -np.log([p, 1-p])
    display(p, logs, sum(logs))

0.99

array([0.01005034, 4.60517019])

4.615220521841592

0.9

array([0.10536052, 2.30258509])

2.4079456086518722

0.8

array([0.22314355, 1.60943791])

1.8325814637483102

0.6

array([0.51082562, 0.91629073])

1.4271163556401456

0.5

array([0.69314718, 0.69314718])

1.3862943611198906

In [39]:
random.sample(replay_buffer, k=4)

[SARS(state=array([ 0.01300044,  0.57370025, -0.0537561 , -0.9193304 ], dtype=float32), action=1, reward=1.0, next_state=array([ 0.02447445,  0.769506  , -0.07214271, -1.2284114 ], dtype=float32), t=3, failed=False, limit=False),
 SARS(state=array([ 0.03986457,  0.96547836, -0.09671093, -1.542797  ], dtype=float32), action=1, reward=1.0, next_state=array([ 0.05917414,  1.1616206 , -0.12756687, -1.8640242 ], dtype=float32), t=5, failed=False, limit=False),
 SARS(state=array([ 0.00543996,  0.37802398, -0.04147855, -0.61387724], dtype=float32), action=1, reward=1.0, next_state=array([ 0.01300044,  0.57370025, -0.0537561 , -0.9193304 ], dtype=float32), t=2, failed=False, limit=False),
 SARS(state=array([ 0.02447445,  0.769506  , -0.07214271, -1.2284114 ], dtype=float32), action=1, reward=1.0, next_state=array([ 0.03986457,  0.96547836, -0.09671093, -1.542797  ], dtype=float32), t=4, failed=False, limit=False)]

# Training

![Psudocode](sac_psudocode.png)

Source: https://spinningup.openai.com/en/latest/algorithms/sac.html#pseudocode

In [40]:
def q_min(q1, q2, states, actions_hot):
    
    def f(q):
        state_values = q.forward(states).detach()
        chosen_action_values = torch.sum(state_values * actions_hot, 1)
        return chosen_action_values
        
    return torch.minimum(*map(f, (q1, q2)))

In [41]:
def to_action_probs(probs):
    r = torch.rand(probs.shape)
    a = torch.max(probs * r, 1).indices
    p = torch.sum(nn.functional.one_hot(a) * probs, 1)
    return torch.cat((a.reshape(-1, 1), p.reshape(-1, 1)), 1)

In [42]:
    stats = {}
    # Step 11
    training_batch = random.sample(replay_buffer, k=min(len(replay_buffer), 100))
    # Step 12
    states = torch.tensor(np.array([sars.state for sars in training_batch]), requires_grad=False)
    actions = torch.tensor(np.array([sars.action for sars in training_batch]), requires_grad=False)
    actions_hot = nn.functional.one_hot(actions)
    rewards = torch.tensor(np.array([sars.reward for sars in training_batch]), requires_grad=False)
    next_states = torch.tensor(np.array([sars.next_state for sars in training_batch]), requires_grad=False)
    fails = torch.tensor(np.array([sars.failed for sars in training_batch]), dtype=int, requires_grad=False)
    next_action_probs = policy.policy_network.forward(next_states).detach()
    next_actions_with_probs = to_action_probs(next_action_probs)
    # The next action chosen by the policy network
    next_actions = next_actions_with_probs[:, 0].long()
    next_actions_hot = nn.functional.one_hot(next_actions)
    # The probability of the next action chosen by the policy network
    next_probs = next_actions_with_probs[:, 1]
    sampled_next_action_q_min = q_min(
        policy.q1_target_network,
        policy.q2_target_network,
        next_states,
        next_actions_hot)
    sampled_next_action_prob = torch.sum(next_action_probs * next_actions_hot, 1)
    sampled_next_action_log_prob = torch.log(sampled_next_action_prob)
    y = rewards + GAMMA * (1-fails) * (sampled_next_action_q_min - ALPHA * sampled_next_action_log_prob)
    # Step 13
    for tensor in (states, actions_hot, y):
        assert not tensor.requires_grad
    for qi, q, opt in ((1, policy.q1_network, policy.q1_optimizer),
                       (2, policy.q2_network, policy.q2_optimizer)):
        opt.zero_grad()
        q_state_action = torch.sum(q.forward(states) * actions_hot, 1)
        assert q_state_action.requires_grad
        q_loss = torch.mean((q_state_action - y)**2)
        stats[f'train/q_loss_{qi}'] = q_loss
        q_loss.backward()
        opt.step()
    # Step 14
    policy.policy_optimizer.zero_grad()
    action_probs = policy.policy_network.forward(states)
    assert action_probs.requires_grad
    sampled_action_probs = to_action_probs(action_probs)
    assert sampled_action_probs.requires_grad
    sampled_actions = sampled_action_probs[:, 0].long()
    assert not sampled_actions.requires_grad
    sampled_actions_hot = nn.functional.one_hot(sampled_actions)
    assert not sampled_actions_hot.requires_grad
    sampled_probs = torch.sum(action_probs * sampled_actions_hot, 1)
    assert sampled_probs.requires_grad
    sampled_log_probs = torch.log(sampled_probs)
    assert sampled_log_probs.requires_grad
    sampled_action_q_min = q_min(policy.q1_network, policy.q2_network, states, sampled_actions.reshape((-1, 1)))
    assert not sampled_action_q_min.requires_grad
    policy_loss = -1 * torch.mean(sampled_action_q_min - ALPHA * sampled_log_probs)
    stats['train/policy_loss'] = policy_loss
    policy_loss.backward()
    policy.policy_optimizer.step()
    # Step 15
    polyak_update(policy.q1_target_network, policy.q1_network, tau=TAU)
    polyak_update(policy.q2_target_network, policy.q2_network, tau=TAU)
    stats

{'train/q_loss_1': tensor(0.8530, dtype=torch.float64, grad_fn=<MeanBackward0>),
 'train/q_loss_2': tensor(1.4261, dtype=torch.float64, grad_fn=<MeanBackward0>),
 'train/policy_loss': tensor(-0.1568, grad_fn=<MulBackward0>)}

In [43]:
def train(policy, replay_buffer):
    stats = {}
    # Step 11
    training_batch = random.sample(replay_buffer, k=min(len(replay_buffer), 100))
    # Step 12
    states = torch.tensor(np.array([sars.state for sars in training_batch]), requires_grad=False)
    actions = torch.tensor(np.array([sars.action for sars in training_batch]), requires_grad=False)
    actions_hot = nn.functional.one_hot(actions)
    rewards = torch.tensor(np.array([sars.reward for sars in training_batch]), requires_grad=False)
    next_states = torch.tensor(np.array([sars.next_state for sars in training_batch]), requires_grad=False)
    fails = torch.tensor(np.array([sars.failed for sars in training_batch]), dtype=int, requires_grad=False)
    next_action_probs = policy.policy_network.forward(next_states).detach()
    next_actions_with_probs = to_action_probs(next_action_probs)
    # The next action chosen by the policy network
    next_actions = next_actions_with_probs[:, 0].long()
    next_actions_hot = nn.functional.one_hot(next_actions)
    # The probability of the next action chosen by the policy network
    next_probs = next_actions_with_probs[:, 1]
    sampled_next_action_q_min = q_min(
        policy.q1_target_network,
        policy.q2_target_network,
        next_states,
        next_actions_hot)
    sampled_next_action_prob = torch.sum(next_action_probs * next_actions_hot, 1)
    sampled_next_action_log_prob = torch.log(sampled_next_action_prob)
    y = rewards + GAMMA * (1-fails) * (sampled_next_action_q_min - ALPHA * sampled_next_action_log_prob)
    # Step 13
    for tensor in (states, actions_hot, y):
        assert not tensor.requires_grad
    for qi, q, opt in ((1, policy.q1_network, policy.q1_optimizer),
                       (2, policy.q2_network, policy.q2_optimizer)):
        opt.zero_grad()
        q_state_action = torch.sum(q.forward(states) * actions_hot, 1)
        assert q_state_action.requires_grad
        q_loss = torch.mean((q_state_action - y)**2)
        stats[f'train/q_loss_{qi}'] = q_loss
        q_loss.backward()
        opt.step()
    # Step 14
    policy.policy_optimizer.zero_grad()
    action_probs = policy.policy_network.forward(states)
    assert action_probs.requires_grad
    sampled_action_probs = to_action_probs(action_probs)
    assert sampled_action_probs.requires_grad
    sampled_actions = sampled_action_probs[:, 0].long()
    assert not sampled_actions.requires_grad
    sampled_actions_hot = nn.functional.one_hot(sampled_actions)
    assert not sampled_actions_hot.requires_grad
    sampled_probs = torch.sum(action_probs * sampled_actions_hot, 1)
    assert sampled_probs.requires_grad
    sampled_log_probs = torch.log(sampled_probs)
    assert sampled_log_probs.requires_grad
    sampled_action_q_min = q_min(policy.q1_network, policy.q2_network, states, sampled_actions.reshape((-1, 1)))
    assert not sampled_action_q_min.requires_grad
    policy_loss = -1 * torch.mean(sampled_action_q_min - ALPHA * sampled_log_probs)
    stats['train/policy_loss'] = policy_loss
    policy_loss.backward()
    policy.policy_optimizer.step()
    # Step 15
    polyak_update(policy.q1_target_network, policy.q1_network, tau=TAU)
    polyak_update(policy.q2_target_network, policy.q2_network, tau=TAU)
    return stats

In [44]:
tb_writer = SummaryWriter()

oss = env.observation_space.shape
if len(oss) != 1:
    raise RuntimeError(f'Unknown observation_space.shape: {oss}')
os_len = oss[0]
policy = Policy(os_len, env.action_space.n)

replay_buffer = deque(maxlen=10_000)

for episode in tqdm.tqdm(range(1, 1000+1)):
    episode_reward = run_episode(action, step, env, policy)
    tb_writer.add_scalar('main/episode_reward', episode_reward, episode)
    tb_writer.add_scalar('main/replay_buffer_length', len(replay_buffer), episode)
    for training_iteration in range(1, 100+1):
        stats = train(policy, replay_buffer)
        for stat, value in stats.items():
            tb_writer.add_scalar(stat, value, episode)

  1%|▉                                                                                                                    | 8/1000 [00:03<07:20,  2.25it/s]


KeyboardInterrupt: 