In [1]:
import gymnasium as gym
import torch
from torch import nn
from collections import namedtuple, deque
import itertools
from copy import deepcopy
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import tqdm
from math import log

In [2]:
# Training environment. Created without a render_mode; a separate rendering
# instance is created for the evaluation run at the end of the notebook.
env = gym.make('LunarLander-v2')

In [3]:
# Target policy entropy: 80% of log(n_actions), the entropy of a uniform policy
# over the discrete actions. The same value is also used as the initial entropy
# coefficient (alpha) when a Policy is constructed, and it is decayed by the
# training loop later in the notebook.
ALPHA_TARGET = 0.8 * log(env.action_space.n)            
GAMMA = 1 - 0.01        # Discount factor: how much we value future rewards.
TAU = 0.01              # Fraction of the online Q weights blended into the target Q networks per polyak update (step 15).
POLICY_LR = 0.001       # Policy learning rate.
Q_LR = 0.001            # Q learning rate.
ALPHA_LR = 0.001        # ALPHA learning rate.

In [4]:
# One replay-buffer transition. `t` is the 1-based step index within its episode;
# `failed` and `limit` are the two end-of-episode flags returned by env.step()
# (in that order), as recorded by run_episode.
SARS = namedtuple('SARS', 'state, action, reward, next_state, t, failed, limit')

In [5]:
# Sanity check: softmax over a 1-D tensor (dim=0) yields values that sum to 1.
# NOTE: `input` shadows the Python builtin of the same name from here on.
softmax = nn.Softmax(dim=0)
input = torch.tensor([1, 2, 3], dtype=float)
display(input)
output = softmax(input)
display(output)
sum(output)

tensor([1., 2., 3.], dtype=torch.float64)

tensor([0.0900, 0.2447, 0.6652], dtype=torch.float64)

tensor(1.0000, dtype=torch.float64)

In [6]:
# Same check for a batch: with dim=1 every row is normalised independently.
softmax = nn.Softmax(dim=1)
input = torch.tensor([[1, 2, 3], [1, 2, 3], [3, 3, 3]], dtype=float)
display(input)
output = softmax(input)
display(output)
# The builtin sum() iterates over the rows and adds them together, so this
# displays per-column totals, which is why the result is not all ones.
sum(output)

tensor([[1., 2., 3.],
        [1., 2., 3.],
        [3., 3., 3.]], dtype=torch.float64)

tensor([[0.0900, 0.2447, 0.6652],
        [0.0900, 0.2447, 0.6652],
        [0.3333, 0.3333, 0.3333]], dtype=torch.float64)

tensor([0.5134, 0.8228, 1.6638], dtype=torch.float64)

In [7]:
class PolicyNetwork(nn.Module):
    """MLP that maps a state vector to a probability distribution over actions.

    Architecture: Linear(input_size, 2000) -> ReLU -> Linear(2000, 1500) -> ReLU
    -> Linear(1500, output_size), followed by a softmax over the last dimension.

    Args:
        input_size: number of features in a state vector.
        output_size: number of discrete actions.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 2000),
            nn.ReLU(),
            nn.Linear(2000, 1500),
            nn.ReLU(),
            nn.Linear(1500, output_size)
        )

    def forward(self, x):
        """Return action probabilities for `x`.

        `x` may be a batch of shape (batch, input_size) or a single unbatched
        state of shape (input_size,). The softmax is taken over the last
        dimension, so each state's probabilities sum to 1 in both cases.
        """
        logits = self.linear_relu_stack(x)
        # dim=-1 (rather than a hard-coded dim=1) keeps the batched result
        # identical while also accepting unbatched 1-D input.
        return torch.softmax(logits, dim=-1)

    def __call__(self, x):
        # Deliberate: callers in this notebook must invoke .forward() explicitly.
        raise RuntimeError("Use forward")

In [8]:
policy_network = PolicyNetwork(4, 2)
policy_network

PolicyNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=4, out_features=2000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2000, out_features=1500, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1500, out_features=2, bias=True)
  )
)

In [9]:
torch.cuda.is_available()

True

In [10]:
# Device used for every network and training tensor below: the GPU when one is
# available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [11]:
mock_states = torch.rand(5, 4)
mock_states

tensor([[0.3078, 0.6021, 0.3211, 0.2887],
        [0.9162, 0.7242, 0.0650, 0.2726],
        [0.2618, 0.4330, 0.6682, 0.1590],
        [0.7153, 0.4364, 0.7828, 0.2448],
        [0.6267, 0.6807, 0.8945, 0.8965]])

In [12]:
policy_network.forward(mock_states)

tensor([[0.5105, 0.4895],
        [0.5101, 0.4899],
        [0.5091, 0.4909],
        [0.5157, 0.4843],
        [0.5201, 0.4799]], grad_fn=<SoftmaxBackward0>)

In [13]:
class QNetwork(nn.Module):
    """MLP that maps a state vector to one Q-value estimate per action.

    Two hidden layers (2000 and 1500 units) with ReLU activations, followed by
    a linear output layer of width `output_size`. No activation is applied to
    the output.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        widths = (input_size, 2000, 1500)
        layers = []
        # Hidden layers: each Linear is followed by a ReLU.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # Output head: raw Q-values.
        layers.append(nn.Linear(widths[-1], output_size))
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw output of the layer stack for `x`."""
        return self.linear_relu_stack(x)

    def __call__(self, x):
        # Deliberate: callers in this notebook must invoke .forward() explicitly.
        raise RuntimeError("Use forward")

In [14]:
q_network = QNetwork(4, 2)
q_network

QNetwork(
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=4, out_features=2000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2000, out_features=1500, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1500, out_features=2, bias=True)
  )
)

In [15]:
q_network.forward(mock_states)

tensor([[ 0.0883, -0.1147],
        [ 0.1121, -0.1403],
        [ 0.1029, -0.1290],
        [ 0.1185, -0.1729],
        [ 0.1125, -0.2135]], grad_fn=<AddmmBackward0>)

In [16]:
def run_episode(action_f, step_f, env, policy):
    """Play one episode in `env` and return the sum of its rewards.

    Args:
        action_f: callable `(policy, state) -> action` used to pick each action.
        step_f: callback invoked after every environment step with
            `(state, action, reward, next_state, t, failed, limit)`, where `t`
            is the 1-based step index within the episode.
        env: environment exposing `reset()`, `step(action)` and
            `_max_episode_steps`.
        policy: passed through unchanged to `action_f`.

    Returns:
        The accumulated reward of the episode.
    """
    total_reward = 0
    state, _ = env.reset()
    t = 0
    while True:
        t += 1
        chosen = action_f(policy, state)
        following, reward, failed, limit, _ = env.step(chosen)
        total_reward += reward
        # Sanity checks: the episode must respect the step limit, and it cannot
        # end for both reasons at once.
        assert t <= env._max_episode_steps
        assert not failed or not limit
        step_f(state, chosen, reward, following, t, failed, limit)
        if failed or limit:
            return total_reward
        state = following

In [17]:
class Policy:
    """Bundle of everything the agent learns.

    Attributes:
        policy_network: network producing action probabilities.
        q1_network, q2_network: the two online Q networks.
        q1_target_network, q2_target_network: deep copies of the online Q
            networks, used as slowly-moving targets.
        alpha: scalar entropy coefficient, a float64 leaf tensor that requires
            grad, initialised to ALPHA_TARGET.
        policy_optimizer, q1_optimizer, q2_optimizer: SGD optimizers created
            by reset_optimizers().
    """

    def __init__(self, env_state_size, env_action_space_size):
        self.policy_network = PolicyNetwork(env_state_size, env_action_space_size)
        self.q1_network = QNetwork(env_state_size, env_action_space_size)
        self.q2_network = QNetwork(env_state_size, env_action_space_size)
        self.q1_target_network = deepcopy(self.q1_network)
        self.q2_target_network = deepcopy(self.q2_network)
        # Create alpha directly on `device`. Tensor.to() returns a new tensor
        # instead of moving the existing one, so calling `.to(device)` and
        # discarding the result would leave alpha on the CPU. Building it on
        # the device also keeps it a leaf tensor that an optimizer can update.
        self.alpha = torch.tensor(ALPHA_TARGET, dtype=float, requires_grad=True, device=device)
        # Module.to() moves a module's parameters in place, so these are fine.
        self.policy_network.to(device)
        self.q1_network.to(device)
        self.q2_network.to(device)
        self.q1_target_network.to(device)
        self.q2_target_network.to(device)
        self.reset_optimizers()

    def reset_optimizers(self):
        """(Re)create the SGD optimizers for the policy and both online Q networks."""
        self.policy_optimizer = torch.optim.SGD(self.policy_network.parameters(), lr=POLICY_LR)
        self.q1_optimizer = torch.optim.SGD(self.q1_network.parameters(), lr=Q_LR)
        self.q2_optimizer = torch.optim.SGD(self.q2_network.parameters(), lr=Q_LR)

    @property
    def alpha_dc(self):
        """Alpha, (D)etached and (C)lamped"""
        # Detached so that losses using it do not backpropagate into alpha;
        # clamped so that the coefficient is never negative.
        return self.alpha.detach().clamp(min=0)

In [18]:
# Check that a bare 0-dim tensor requiring grad can be handed to an SGD
# optimizer on its own (this is how alpha is optimised in train()).
t = torch.tensor(1, dtype=float, requires_grad=True)
display(t)
torch.optim.SGD([t], lr=ALPHA_LR)

tensor(1., dtype=torch.float64, requires_grad=True)

SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.001
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)

In [19]:
# Bounded FIFO of SARS transitions: once 30,000 are stored, appending a new one
# discards the oldest.
replay_buffer = deque(maxlen=30_000)

In [20]:
env.action_space.n

4

In [21]:
env.observation_space.shape

(8,)

In [22]:
# The networks take a flat state vector, so only 1-D observation spaces are
# supported; anything else is rejected up front.
oss = env.observation_space.shape
if len(oss) != 1:
    raise RuntimeError(f'Unknown observation_space.shape: {oss}')
os_len = oss[0]
# State size comes from the observation space, action count from the action space.
policy = Policy(os_len, env.action_space.n)

In [23]:
s, info = env.reset()
s

array([ 0.00466385,  1.4123211 ,  0.4723712 ,  0.06226536, -0.00539734,
       -0.10699912,  0.        ,  0.        ], dtype=float32)

In [24]:
s = torch.tensor(s).reshape((1, -1))
s

tensor([[ 0.0047,  1.4123,  0.4724,  0.0623, -0.0054, -0.1070,  0.0000,  0.0000]])

In [25]:
policy_output = policy.policy_network.forward(s.to(device))
policy_output

tensor([[0.2569, 0.2318, 0.2781, 0.2332]], device='cuda:0',
       grad_fn=<SoftmaxBackward0>)

In [26]:
action_weights = policy_output.reshape((-1,)).tolist()

In [27]:
list(range(4))

[0, 1, 2, 3]

In [28]:
action = random.choices(range(len(action_weights)), weights=action_weights)[0]
action

3

In [29]:
def action(policy, s):
    """Sample a discrete action for state `s` from the policy's distribution.

    Args:
        policy: object whose `policy_network.forward` returns action
            probabilities for a batch of states.
        s: a single state, as returned by the environment.

    Returns:
        The index of the sampled action.
    """
    # Reshape the single state into a batch of one before the forward pass.
    tensor_s = torch.tensor(s).reshape((1, -1)).to(device)
    # Acting never needs gradients; no_grad avoids building an autograd graph
    # on every environment step.
    with torch.no_grad():
        action_weights = policy.policy_network.forward(tensor_s).reshape((-1,)).tolist()
    # Draw one action index with probability proportional to its weight.
    chosen = random.choices(range(len(action_weights)), weights=action_weights)[0]
    return chosen

def step(initial_s, a, r, next_s, t, failed, limit):
    """Step callback for run_episode: record the transition in the replay buffer."""
    transition = SARS(
        state=initial_s,
        action=a,
        reward=r,
        next_state=next_s,
        t=t,
        failed=failed,
        limit=limit,
    )
    replay_buffer.append(transition)

In [30]:
run_episode(action, step, env, policy)

  if not isinstance(terminated, (bool, np.bool8)):


-91.20855692557728

In [31]:
len(replay_buffer)

70

In [32]:
replay_buffer

deque([SARS(state=array([-0.00549459,  1.4016852 , -0.5565525 , -0.41045275,  0.0063736 ,
               0.12606737,  0.        ,  0.        ], dtype=float32), action=2, reward=2.266187979900633, next_state=array([-0.01100483,  1.3934011 , -0.5572752 , -0.36822054,  0.01256057,
               0.12375202,  0.        ,  0.        ], dtype=float32), t=1, failed=False, limit=False),
       SARS(state=array([-0.01100483,  1.3934011 , -0.5572752 , -0.36822054,  0.01256057,
               0.12375202,  0.        ,  0.        ], dtype=float32), action=3, reward=-0.44508311116174926, next_state=array([-0.01644936,  1.3845133 , -0.5490331 , -0.39506158,  0.01708808,
               0.0905586 ,  0.        ,  0.        ], dtype=float32), t=2, failed=False, limit=False),
       SARS(state=array([-0.01644936,  1.3845133 , -0.5490331 , -0.39506158,  0.01708808,
               0.0905586 ,  0.        ,  0.        ], dtype=float32), action=2, reward=0.8300302230607428, next_state=array([-0.02206526,  1.37

# Polyak Averaging

In [33]:
test_parameter_1 = next(policy.policy_network.named_parameters())[1]
test_parameter_1

Parameter containing:
tensor([[ 0.2055, -0.2319,  0.2028,  ...,  0.0922, -0.2282, -0.3254],
        [ 0.2028, -0.3522, -0.3411,  ...,  0.1852,  0.2460, -0.2618],
        [-0.1746, -0.3408, -0.1687,  ...,  0.2524,  0.1196,  0.1263],
        ...,
        [-0.1459, -0.1217, -0.0412,  ..., -0.1848,  0.2856, -0.3305],
        [ 0.1725,  0.1390, -0.0070,  ..., -0.1292,  0.0657, -0.1553],
        [ 0.0756,  0.2674,  0.2297,  ..., -0.1144, -0.0216,  0.1884]],
       device='cuda:0', requires_grad=True)

In [34]:
test_parameter_2 = test_parameter_1 * 0 + 0.0128
test_parameter_2

tensor([[0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        ...,
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128],
        [0.0128, 0.0128, 0.0128,  ..., 0.0128, 0.0128, 0.0128]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [35]:
test_parameter_1 * 0.9 + test_parameter_2 * 0.1

tensor([[ 0.1862, -0.2074,  0.1838,  ...,  0.0843, -0.2041, -0.2916],
        [ 0.1838, -0.3157, -0.3057,  ...,  0.1679,  0.2227, -0.2343],
        [-0.1559, -0.3054, -0.1505,  ...,  0.2284,  0.1089,  0.1150],
        ...,
        [-0.1301, -0.1083, -0.0358,  ..., -0.1650,  0.2583, -0.2962],
        [ 0.1565,  0.1264, -0.0050,  ..., -0.1150,  0.0604, -0.1385],
        [ 0.0694,  0.2419,  0.2080,  ..., -0.1017, -0.0181,  0.1708]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [36]:
def polyak_update(network_to_update, target_network, tau=0.001):
    """Move `network_to_update` a fraction `tau` of the way toward `target_network`.

    Every parameter of `network_to_update` is overwritten in place with
    `(1 - tau) * own_value + tau * other_value`, pairing parameters by their
    order in `.parameters()`. `target_network` is only read, never modified.
    Note the naming: `target_network` is the network being moved *toward*.
    """
    keep = 1 - tau
    pairs = zip(network_to_update.parameters(), target_network.parameters())
    # In-place edits of parameters must not be recorded by autograd.
    with torch.no_grad():
        for dest, src in pairs:
            dest.mul_(keep)
            dest.add_(src * tau)

In [37]:
# Smoke test for polyak_update: test_network_2 is moved 10% of the way toward
# test_network_1. The first parameter of each network is displayed before and
# after the call; test_network_1 should be unchanged.
test_network_1 = QNetwork(5, 3)
test_network_2 = QNetwork(5, 3)
display(list(test_network_1.parameters())[0])
display(list(test_network_2.parameters())[0])
polyak_update(test_network_2, test_network_1, 0.1)
display(list(test_network_1.parameters())[0])
display(list(test_network_2.parameters())[0])

Parameter containing:
tensor([[-1.1490e-01, -1.1928e-02,  4.1765e-01, -3.2731e-01,  1.8876e-01],
        [-2.3200e-01, -4.8993e-02,  4.3404e-01,  1.4395e-01,  1.9092e-01],
        [-1.0534e-04, -3.0931e-01, -1.4797e-01, -2.4361e-01, -1.5164e-01],
        ...,
        [-4.4118e-01, -1.2890e-01, -3.0154e-02, -3.3243e-01, -1.9251e-01],
        [-1.9203e-01,  3.5193e-01,  7.4533e-02,  2.3991e-01, -8.3312e-02],
        [-1.2111e-01,  2.6557e-01,  2.3603e-01,  7.0378e-02,  1.5150e-01]],
       requires_grad=True)

Parameter containing:
tensor([[-0.0685, -0.3881, -0.4062, -0.1459,  0.2815],
        [ 0.3961, -0.4252, -0.1942,  0.4413, -0.0097],
        [-0.2315, -0.2041, -0.0810,  0.3366, -0.2691],
        ...,
        [ 0.4086,  0.1911,  0.2079, -0.1103, -0.1733],
        [ 0.1261,  0.2942,  0.1815, -0.0397,  0.1583],
        [ 0.1458,  0.2484, -0.2127, -0.1054, -0.4002]], requires_grad=True)

Parameter containing:
tensor([[-1.1490e-01, -1.1928e-02,  4.1765e-01, -3.2731e-01,  1.8876e-01],
        [-2.3200e-01, -4.8993e-02,  4.3404e-01,  1.4395e-01,  1.9092e-01],
        [-1.0534e-04, -3.0931e-01, -1.4797e-01, -2.4361e-01, -1.5164e-01],
        ...,
        [-4.4118e-01, -1.2890e-01, -3.0154e-02, -3.3243e-01, -1.9251e-01],
        [-1.9203e-01,  3.5193e-01,  7.4533e-02,  2.3991e-01, -8.3312e-02],
        [-1.2111e-01,  2.6557e-01,  2.3603e-01,  7.0378e-02,  1.5150e-01]],
       requires_grad=True)

Parameter containing:
tensor([[-0.0731, -0.3505, -0.3238, -0.1641,  0.2722],
        [ 0.3333, -0.3876, -0.1313,  0.4115,  0.0103],
        [-0.2084, -0.2146, -0.0877,  0.2786, -0.2574],
        ...,
        [ 0.3236,  0.1591,  0.1841, -0.1326, -0.1752],
        [ 0.0943,  0.2999,  0.1708, -0.0117,  0.1342],
        [ 0.1191,  0.2501, -0.1678, -0.0878, -0.3450]], requires_grad=True)

# Log Experiments

In [38]:
# Explore -log(p) for a two-outcome distribution [p, 1-p] as p moves from
# nearly certain (0.99) to uniform (0.5). Shows each term and their plain sum.
for p in [0.99, 0.9, 0.8, 0.6, 0.5]:
    logs = -np.log([p, 1-p])
    display(p, logs, sum(logs))

0.99

array([0.01005034, 4.60517019])

4.615220521841592

0.9

array([0.10536052, 2.30258509])

2.4079456086518722

0.8

array([0.22314355, 1.60943791])

1.8325814637483102

0.6

array([0.51082562, 0.91629073])

1.4271163556401456

0.5

array([0.69314718, 0.69314718])

1.3862943611198906

In [39]:
random.sample(replay_buffer, k=4)

[SARS(state=array([-0.42253333,  0.00773657, -0.63417655, -1.4615089 , -0.00227488,
        -0.21600907,  0.        ,  0.        ], dtype=float32), action=3, reward=5.9156894854490085, next_state=array([-0.42889762, -0.02575163, -0.62327385, -1.4884341 , -0.01525886,
        -0.25967962,  1.        ,  0.        ], dtype=float32), t=69, failed=False, limit=False),
 SARS(state=array([-0.08050861,  1.268313  , -0.6130634 , -0.52033454,  0.10132137,
         0.17359746,  0.        ,  0.        ], dtype=float32), action=0, reward=-1.4346597211933272, next_state=array([-0.08655224,  1.2560189 , -0.61308604, -0.5470208 ,  0.1099989 ,
         0.17356627,  0.        ,  0.        ], dtype=float32), t=14, failed=False, limit=False),
 SARS(state=array([-0.15680256,  1.0631496 , -0.5811101 , -0.8218586 ,  0.14920697,
        -0.05871984,  0.        ,  0.        ], dtype=float32), action=0, reward=-0.09461562421537906, next_state=array([-0.16264305,  1.044051  , -0.5811099 , -0.84852594,  0.1462709

# Training

![Pseudocode](sac_psudocode.png)

Source: https://spinningup.openai.com/en/latest/algorithms/sac.html#pseudocode

In [40]:
def q_min(q1, q2, states):
    """Element-wise minimum of two Q networks' outputs for `states`.

    Args:
        q1, q2: networks whose `forward` returns one value per action.
        states: batch of states; moved to `device` before evaluation.

    Returns:
        A tensor with the same shape as each network's output, holding the
        smaller of the two estimates per entry. It carries no gradient
        history, so it can be used as a constant in losses.
    """
    states = states.to(device)  # transferred once, not once per network
    # no_grad gives the same gradient-free result as detaching afterwards,
    # without first building an autograd graph that is thrown away.
    with torch.no_grad():
        return torch.minimum(q1.forward(states), q2.forward(states))

In [41]:
def _sum_p_log_p(probs, eps=1e-8):
    """Per-row sum of p * log(p), i.e. the negative entropy of each row of `probs`.

    The argument of the log is clamped to at least `eps`, so a probability that
    has underflowed to exactly 0 contributes 0 instead of 0 * -inf = NaN (which
    would otherwise propagate into every loss and every parameter).
    """
    return torch.sum(probs * torch.log(probs.clamp_min(eps)), 1)


def train(policy, replay_buffer):
    """Run one discrete-action SAC update on a random minibatch.

    Step numbers in the comments refer to the Spinning Up SAC pseudocode linked
    above. Because the action space is discrete, expectations over actions are
    computed exactly as probability-weighted sums rather than by sampling.

    Updates, in order: both online Q networks, the policy network, the entropy
    coefficient `policy.alpha`, and finally both target Q networks (polyak).

    Args:
        policy: the Policy whose networks, optimizers and alpha are updated.
        replay_buffer: sequence of SARS transitions to sample from.

    Returns:
        dict mapping TensorBoard tag -> detached scalar tensor for this update.
    """
    stats = {}
    # Step 11: sample a random minibatch of at most 100 transitions.
    training_batch = random.sample(replay_buffer, k=min(len(replay_buffer), 100))
    # Prep: stack the batch into tensors on `device`.
    states = torch.tensor(np.array([sars.state for sars in training_batch]), requires_grad=False).to(device)
    # one_hot requires int64 indices; be explicit instead of relying on the
    # platform's default numpy integer width.
    actions = torch.tensor(np.array([sars.action for sars in training_batch]), dtype=torch.long, requires_grad=False).to(device)
    actions_hot = nn.functional.one_hot(actions, env.action_space.n).to(device)
    rewards = torch.tensor(np.array([sars.reward for sars in training_batch]), requires_grad=False).to(device)
    next_states = torch.tensor(np.array([sars.next_state for sars in training_batch]), requires_grad=False).to(device)
    # Only `failed` masks the bootstrap term below; `limit` is not consulted,
    # so transitions that ended on the step limit still bootstrap.
    fails = torch.tensor(np.array([sars.failed for sars in training_batch]), dtype=int, requires_grad=False).to(device)
    # Step 12: compute the Q targets y (treated as constants, no gradients).
    next_action_probs = policy.policy_network.forward(next_states).detach()
    assert not next_action_probs.requires_grad
    next_states_q_min = q_min(policy.q1_target_network, policy.q2_target_network, next_states)
    assert not next_states_q_min.requires_grad
    # Expected min-Q of the next state under the current policy.
    next_actions_q_min = torch.sum(next_states_q_min * next_action_probs, 1)
    assert not next_actions_q_min.requires_grad
    # Sum of p*log(p), i.e. the NEGATIVE entropy; subtracting alpha times it
    # below therefore adds an entropy bonus.
    next_actions_neg_entropy = _sum_p_log_p(next_action_probs)
    assert not next_actions_neg_entropy.requires_grad
    y = rewards + GAMMA * (1-fails) * (next_actions_q_min - policy.alpha_dc * next_actions_neg_entropy)
    assert not y.requires_grad
    # Step 13: regress each online Q network's value for the taken action onto y.
    for qi, q, opt in ((1, policy.q1_network, policy.q1_optimizer),
                       (2, policy.q2_network, policy.q2_optimizer)):
        assert not states.requires_grad
        assert not actions_hot.requires_grad
        # The one-hot mask selects the Q-value of the action actually taken.
        q_state_action = torch.sum(q.forward(states) * actions_hot, 1)
        assert q_state_action.requires_grad
        q_loss = torch.mean((q_state_action - y)**2)
        assert q_loss.requires_grad
        stats[f'train/q_loss_{qi}'] = q_loss.detach()
        opt.zero_grad()
        q_loss.backward()
        opt.step()
    # Step 14: improve the policy. Maximise expected min-Q plus alpha * entropy,
    # with the (just updated) online Q networks held constant.
    action_probs = policy.policy_network.forward(states)
    assert action_probs.requires_grad
    states_q_min = q_min(policy.q1_network, policy.q2_network, states)
    assert not states_q_min.requires_grad
    actions_q_min = torch.sum(states_q_min * action_probs, 1)
    assert actions_q_min.requires_grad
    actions_neg_entropy = _sum_p_log_p(action_probs)
    assert actions_neg_entropy.requires_grad
    stats['train/entropy'] = -actions_neg_entropy.detach().mean()
    # Entropy as a fraction of the maximum, log(n_actions).
    stats['train/entropy_percent'] = -actions_neg_entropy.detach().mean() / log(action_probs.shape[1])
    policy_loss = -1 * torch.mean(actions_q_min - policy.alpha_dc * actions_neg_entropy)
    assert policy_loss.requires_grad
    stats['train/policy_loss'] = policy_loss.detach()
    policy.policy_optimizer.zero_grad()
    policy_loss.backward()
    policy.policy_optimizer.step()
    # Alpha Adjust: d(alpha_loss)/d(alpha) = mean_entropy - ALPHA_TARGET, so SGD
    # lowers alpha while the policy's entropy is above the target and raises it
    # while the entropy is below the target.
    assert policy.alpha.requires_grad
    assert action_probs.requires_grad
    policy.alpha.requires_grad_(True)
    alpha_optimizer = torch.optim.SGD([policy.alpha], lr=ALPHA_LR)
    alpha_loss = -1 * policy.alpha * (actions_neg_entropy.detach().mean() + ALPHA_TARGET)
    assert alpha_loss.requires_grad
    stats['train/alpha_loss'] = alpha_loss.detach()
    alpha_optimizer.zero_grad()
    alpha_loss.backward()
    alpha_optimizer.step()
    # Keep the stored coefficient non-negative after the SGD step.
    policy.alpha.data.clamp_(min=0)
    stats['train/alpha'] = policy.alpha.detach()
    # Step 15: let the target Q networks track the online ones.
    polyak_update(policy.q1_target_network, policy.q1_network, tau=TAU)
    polyak_update(policy.q2_target_network, policy.q2_network, tau=TAU)
    return stats

In [42]:
tb_writer = SummaryWriter()

# Start from a fresh agent and an empty replay buffer so this cell does not
# depend on the exploratory cells above.
oss = env.observation_space.shape
if len(oss) != 1:
    raise RuntimeError(f'Unknown observation_space.shape: {oss}')
os_len = oss[0]
policy = Policy(os_len, env.action_space.n)

replay_buffer = deque(maxlen=30_000)

def action(policy, s):
    """Sample a discrete action for state `s` from the policy's distribution."""
    tensor_s = torch.tensor(s).reshape((1, -1))
    # Acting never needs gradients; no_grad avoids building an autograd graph
    # on every environment step.
    with torch.no_grad():
        action_weights = policy.policy_network.forward(tensor_s.to(device)).reshape((-1,)).tolist()
    chosen = random.choices(range(len(action_weights)), weights=action_weights)[0]
    return chosen

def step(initial_s, a, r, next_s, t, failed, limit):
    """Step callback for run_episode: record the transition in the replay buffer."""
    replay_buffer.append(SARS(initial_s, a, r, next_s, t, failed, limit))

n_episodes = 2000
# Running count of train() calls. It is the x-axis for the train/* scalars, so
# that the 200 updates of one episode are not all written at the same step.
train_step = 0
for episode in tqdm.tqdm(range(1, n_episodes+1)):
    # Collect one episode of experience with the current policy.
    episode_reward = run_episode(action, step, env, policy)
    tb_writer.add_scalar('main/episode_reward', episode_reward, episode)
    tb_writer.add_scalar('main/replay_buffer_length', len(replay_buffer), episode)
    tb_writer.add_scalar('main/alpha_target', ALPHA_TARGET, episode)
    tb_writer.add_scalar('main/alpha_target_percent', ALPHA_TARGET / log(env.action_space.n), episode)
    policy.reset_optimizers()
    # 200 gradient updates per collected episode.
    for training_iteration in range(1, 200+1):
        stats = train(policy, replay_buffer)
        train_step += 1
        for stat, value in stats.items():
            tb_writer.add_scalar(stat, value, train_step)
    # After the first quarter of training, decay the entropy target so that it
    # halves every 800 episodes. NOTE: this mutates the module-level
    # ALPHA_TARGET, so re-running this cell continues from the decayed value.
    if episode > n_episodes / 4:
        ALPHA_TARGET *= 0.5**(1/800)

# Make sure everything logged so far is written to disk.
tb_writer.flush()

100%|████████████████████████████████████████████████████████████████████████████| 2000/2000 [43:49<00:00,  1.31s/it]


In [43]:
# Evaluation: replay the trained policy in a rendering environment.
# NOTE: this rebinds the module-level `env` that train() reads.
env = gym.make('LunarLander-v2', render_mode='human')

def render_only(initial_s, a, r, next_s, t, failed, limit):
    """No-op step callback: evaluation episodes are not added to the replay buffer."""
    pass

scores = []
try:
    for episode in range(1, 10+1):
        scores.append(run_episode(action, render_only, env, policy))

    print(f'Mean score: {sum(scores) / len(scores):.0f}')
finally:
    # Close the environment (and its render window) even if an episode raises.
    env.close()

Mean score: 279
