See also: the [PyTorch CartPole DQN tutorial](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)

In [None]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
import random
import matplotlib.pyplot as plt
from collections import deque
from tqdm import tqdm
from dataclasses import dataclass
from itertools import count
import random
import math
from typing import List, Tuple, Dict, Any, Optional, Callable, Union

In [None]:
# make env
import platoonenv
from gym.wrappers import TimeLimit # type: ignore

# envName = "CartPole-v0"
# envName = "MountainCar-v0"
# envName = "Acrobot-v1"

# Exactly one environment id is active at a time; the alternatives stay
# commented out so that switching is a one-line change.
# envName = "Platoon-v0"
envName = "Platoon-v1"
# envName = "Platoon-v2"

def make_env(render_mode: Union[str, None] = None) -> gym.Env[np.ndarray, int]:
    """Create a new instance of the environment named by the global ``envName``.

    ``render_mode`` is forwarded unchanged to ``gym.make``.
    """
    return gym.make(envName, render_mode=render_mode)

# notebook-wide environment used for exploration and training in the cells below
env = make_env(render_mode="human")

In [None]:
# inspect the observation space; its shape sizes the DQN input layer further down
print(env.observation_space)

In [None]:
# inspect what reset() returns; later cells unpack it as (observation, info)
print(env.reset(seed=1))

In [None]:
# sanity check: the first observation after reset() must be determined by the seed
a = env.reset(seed=1)[0]
b = env.reset(seed=1)[0]
c = env.reset(seed=2)[0]
# same seed -> identical observation
assert np.array_equal(a.flatten(),b.flatten()), "env should respect reset seed"
# different seed -> different observation (this check is skipped for Platoon-v0)
assert envName == "Platoon-v0" or not np.array_equal(b,c), "env should respect reset seed"

In [None]:
# inspect the action space; `n` is used below as the number of DQN outputs
print(env.action_space)
print(env.action_space.n)

In [None]:
# episode reward that test() compares against to decide whether a run passed
print(env.spec.reward_threshold)

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network.

    Each sample is flattened, passed through one
    ``Linear -> BatchNorm1d -> ReLU`` stage per entry of ``hidden_shapes``
    and finally projected onto ``num_actions`` outputs.
    """

    def __init__(self, obs_shape: Tuple[int,...], hidden_shapes: Tuple[int,...], num_actions: int):
        super().__init__()
        assert len(hidden_shapes) > 0
        # layer widths, starting with the flattened observation size
        widths = (math.prod(obs_shape), ) + hidden_shapes
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.ReLU()]
        # output head: one value per action, no activation
        layers.append(nn.Linear(widths[-1], num_actions))
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Return the network output for a batch; dim 0 of ``x`` is the batch."""
        return self.net(x.flatten(start_dim=1))

In [None]:
# instantiate models
assert env.observation_space.shape is not None
assert env.action_space.n is not None # type: ignore
# HIDDEN_SHAPES = (128,128,128,128)
HIDDEN_SHAPES = (512,512,512,512)
# Two identically shaped networks, built one after the other: the online
# policy first, then the target network.
policy, policy_target = (
    DQN(
        obs_shape = env.observation_space.shape,
        hidden_shapes = HIDDEN_SHAPES,
        num_actions = env.action_space.n, # type: ignore
    )
    for _ in range(2)
)
# the target network is only ever used for inference
policy_target.eval()
print(policy)

In [None]:
# weight update functions
# from https://github.com/ghliu/pytorch-ddpg/blob/master/util.py#L26
def soft_update(target, source, tau):
    """Polyak-average ``source`` into ``target`` in place.

    Every parameter of ``target`` becomes
    ``(1 - tau) * target_param + tau * source_param``. Buffers (for example
    the running statistics of ``nn.BatchNorm1d``) are not part of
    ``parameters()``; they are copied over verbatim so that a target network
    run in ``eval()`` mode normalises with the source's statistics.

    Args:
        target: module that is modified in place.
        source: module that is blended in; it is zipped positionally with
            ``target``, so both must have the same architecture.
        tau: interpolation factor; ``0`` leaves the parameters of ``target``
            unchanged, ``1`` makes them equal to those of ``source``.
    """
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.copy_(target_param * (1.0 - tau) + param * tau)
        for target_buffer, buffer in zip(target.buffers(), source.buffers()):
            target_buffer.copy_(buffer)

def hard_update(target, source):
    """Make ``target`` an exact copy of ``source`` by loading its state dict."""
    snapshot = source.state_dict()
    target.load_state_dict(snapshot)

In [None]:
# start training with the target network identical to the online policy
hard_update(policy_target, policy)

In [None]:
@dataclass
class Transition:
    """One environment step as stored in the replay buffers."""
    # observation the action was selected from
    observation: Tensor
    # index of the action that was taken
    action: int
    # reward returned by env.step() for that action
    reward: float
    # observation after the step; the training loop stores zeros here when the episode ended
    next_state: Tensor
    # True when the episode ended on this step (the training loop passes `done or trunc`)
    finished: bool

@dataclass
class TransitionBatch:
    """A minibatch of transitions, stacked field by field along dim 0.

    ``indices``, ``priorities`` and ``weights`` are only filled in by the
    prioritized buffer; the uniform buffer leaves them as ``None``, which is
    why they are optional and default to ``None``.
    """
    # stacked Transition.observation values
    observations: Tensor
    # int64 action indices
    actions: Tensor
    # float32 rewards
    rewards: Tensor
    # stacked Transition.next_state values
    next_observations: Tensor
    # bool flags, True where the episode ended on that step
    finished: Tensor
    # buffer slots the samples came from (prioritized buffer only)
    indices: Optional[Tensor] = None
    # priorities stored for those slots (prioritized buffer only)
    priorities: Optional[Tensor] = None
    # importance-sampling weights, scaled so the largest is 1 (prioritized buffer only)
    weights: Optional[Tensor] = None

In [None]:
class PrioritizedReplayBuffer:
    """Fixed-capacity replay buffer with proportional prioritized sampling.

    Transitions are kept in a ring buffer (the oldest slot is overwritten
    once ``capacity`` is reached). The priority of slot ``i`` is stored in a
    ``SumTree`` under the same index ``i``.

    Args:
        capacity: number of slots in the ring buffer.
        alpha: exponent applied to ``priority + epsilon`` in ``update``.
        beta: exponent of the importance-sampling weights in ``sample``.
        epsilon: small constant added to priorities before they are stored.
    """

    def __init__(
        self,
        capacity: int,
        alpha: float,
        beta: float,
        epsilon: float,
    ):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.epsilon = epsilon
        # imported lazily so the notebook only needs `sumtree` when this buffer is used
        from sumtree import SumTree
        self.tree = SumTree(capacity)
        self.transitions: List[Union[None, Transition]] = [None for _ in range(capacity)]
        # slot that the next append() writes to
        self.current_index = 0
        # number of filled slots, capped at capacity
        self.len = 0

    def append(self, transition: Transition):
        """Store ``transition`` in the next slot, overwriting the oldest entry when full."""
        # from the paper: New transitions arrive without a known TD-error,
        # so we put them at maximal priority in order to guarantee that
        # all experience is seen at least once.
        # Here "maximal" is the fixed value 1000 (+ epsilon).
        priority = 1000 + self.epsilon
        self.tree.update(self.current_index, priority) # type: ignore
        self.transitions[self.current_index] = transition
        self.current_index = (self.current_index + 1) % self.capacity
        self.len = min(self.len + 1, self.capacity)

    def update(self, indices: List[int], priorities: List[float]):
        """Store ``(priority + epsilon) ** alpha`` for every given slot index.

        The training loop passes the absolute TD errors of the last minibatch.
        """
        # from the paper: The TD-error is updated after each minibatch
        # update, and the priorities are updated accordingly.
        for i, p in zip(indices, priorities):
            self.tree.update(i, (p + self.epsilon) ** self.alpha) # type: ignore

    def sample(self, batch_size: int) -> TransitionBatch:
        """Draw ``batch_size`` transitions with probability proportional to priority.

        Returns:
            A ``TransitionBatch`` whose ``indices``, ``priorities`` and
            ``weights`` describe the sampled slots.

        Raises:
            ValueError: if the buffer is empty or no stored transition could
                be drawn.
        """
        if self.len == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        # from the paper:
        # The ‘sum-tree’ data structure used here is very similar in spirit to the
        # array representation of a binary heap. This provides a efficient way of
        # calculating the cumulative sum of priorities, allowing O(logN)
        # updates and sampling. To sample a minibatch of size k, the
        # range [0; ptotal] is divided equally into k ranges. Next, a
        # value is uniformly sampled from each range. Finally the transitions
        # that correspond to each of these sampled values are retrieved from
        # the tree. Overhead is similar to rank-based prioritization.
        total = self.tree.total
        # batch_size + 1 edges delimit batch_size equal segments of [0, total]
        edges = torch.linspace(0, total, batch_size+1)
        lows, highs = edges[:-1], edges[1:]
        # one uniformly random cumulative-sum value per segment
        cumsums = (lows + (highs - lows) * torch.rand(batch_size)).tolist()

        # look up the slot for each value; a lookup that lands on a slot that
        # was never written is recorded as None for now
        picked = []
        for x in cumsums:
            index, priority = self.tree.get(x)
            trans = self.transitions[index]
            picked.append(None if trans is None else (index, priority, trans))

        # Draws that hit an empty slot reuse the first valid draw. Doing this
        # after the loop means it also works when the very first draw is the
        # one that missed.
        fallback = next((p for p in picked if p is not None), None)
        if fallback is None:
            raise ValueError("no stored transition could be sampled")
        picked = [fallback if p is None else p for p in picked]

        indices = [p[0] for p in picked]
        transitions = [p[2] for p in picked]

        # from the paper:
        # P(i) = p_i^alpha / sum_k{p_k^alpha}
        # where p_i = |TD-error| + epsilon
        priorities = torch.as_tensor([p[1] for p in picked], dtype=torch.float32)
        probabilities = priorities / total
        # importance-sampling weights, normalised so the largest weight is 1
        weights = (self.capacity * probabilities) ** (-self.beta)
        weights /= weights.max()

        return TransitionBatch(
            observations=torch.stack([x.observation for x in transitions]),
            actions=torch.as_tensor([x.action for x in transitions], dtype=torch.int64),
            rewards=torch.as_tensor([x.reward for x in transitions], dtype=torch.float32),
            next_observations=torch.stack([x.next_state for x in transitions]),
            finished=torch.as_tensor([x.finished for x in transitions], dtype=torch.bool),
            indices=torch.as_tensor(indices, dtype=torch.int64),
            priorities=priorities,
            weights=weights,
        )

    def __len__(self) -> int:
        """Return the number of stored transitions (at most ``capacity``)."""
        return self.len

In [None]:
class DequeReplayBuffer:
    """Replay buffer with uniform sampling, backed by a bounded ``deque``."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        # with maxlen set, appending to a full deque discards the oldest entry
        self.buffer = deque(maxlen=capacity)

    def append(self, transition: Transition):
        """Add ``transition`` as the newest entry."""
        self.buffer.append(transition)

    def sample(self, batch_size: int) -> TransitionBatch:
        """Return ``batch_size`` distinct transitions drawn without replacement.

        The prioritized-only fields of the batch are set to ``None``.
        """
        drawn = random.sample(self.buffer, batch_size)
        obs, acts, rews, next_obs, ends = [], [], [], [], []
        for item in drawn:
            obs.append(item.observation)
            acts.append(item.action)
            rews.append(item.reward)
            next_obs.append(item.next_state)
            ends.append(item.finished)
        return TransitionBatch(
            observations=torch.stack(obs),
            actions=torch.as_tensor(acts, dtype=torch.int64),
            rewards=torch.as_tensor(rews, dtype=torch.float32),
            next_observations=torch.stack(next_obs),
            finished=torch.as_tensor(ends, dtype=torch.bool),
            indices=None,
            priorities=None,
            weights=None,
        )

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)


In [None]:
# replay memory: pick the buffer implementation by name
# MEMORY_SIZE = 10000
MEMORY_SIZE = 10_000_000
# MEMORY_TYPE = "prioritized"
MEMORY_TYPE = "deque"
if MEMORY_TYPE == "prioritized":
    memory = PrioritizedReplayBuffer(
        capacity=MEMORY_SIZE,
        alpha=0.6,
        beta=0.4,
        epsilon=0.01,
    )
elif MEMORY_TYPE == "deque":
    memory = DequeReplayBuffer(MEMORY_SIZE)
else:
    # name the offending value so a typo in MEMORY_TYPE is obvious
    raise ValueError(f"unknown MEMORY_TYPE {MEMORY_TYPE!r}; expected 'prioritized' or 'deque'")

In [None]:
def get_exploration_epsilon(
    steps_done: int,
    epsilon_start: float = 0.9,
    epsilon_end: float = 0.05,
    epsilon_decay: float = 10000,
) -> float:
    """Return the exploration rate after ``steps_done`` environment steps.

    The rate decays exponentially from ``epsilon_start`` towards
    ``epsilon_end``:
    ``epsilon_end + (epsilon_start - epsilon_end) * exp(-steps_done / epsilon_decay)``.

    Args:
        steps_done: number of steps taken so far.
        epsilon_start: rate at step 0.
        epsilon_end: rate approached as ``steps_done`` grows.
        epsilon_decay: decay time constant in steps; larger values decay slower.
    """
    return epsilon_end + (epsilon_start - epsilon_end) * math.exp(-1. * steps_done / epsilon_decay)

In [None]:
# plot epsilon
# %matplotlib ipympl
%matplotlib inline
plt.figure()
plt.style.use('dark_background')
plt.title("exploration epsilon")
plt.plot([get_exploration_epsilon(i) for i in range(1000000) if get_exploration_epsilon(i) >= 0.1])
plt.show()

In [None]:
def get_action(obs: Tensor, epsilon: float) -> int:
    """Choose an action epsilon-greedily with the global ``policy`` and ``env``.

    With probability ``epsilon`` a random action is sampled from the action
    space; otherwise the action with the largest policy output is returned.
    The policy is switched to ``eval()`` mode for the greedy evaluation.
    """
    explore = random.random() < epsilon
    if explore:
        return env.action_space.sample()
    policy.eval()
    with torch.no_grad():
        # the network expects a batch, so add a leading batch dimension of 1
        q_values = policy(obs.unsqueeze(0))
    greedy = q_values.argmax(dim=1)
    return int(greedy.squeeze())

In [None]:
def test(render=False, seed:int=42) -> Tuple[bool, float]:
    """Run one greedy episode in a fresh environment.

    Args:
        render: create the environment with ``render_mode="human"`` and call
            ``render()`` after every step.
        seed: seed passed to ``env.reset``.

    Returns:
        ``(passed, episode_reward)`` where ``passed`` is True when the episode
        reward reached the reward threshold of the environment spec.
    """
    env = make_env(render_mode="human") if render else make_env()
    try:
        assert env.spec.reward_threshold is not None
        obs, info = env.reset(seed=seed)
        obs = torch.as_tensor(obs, dtype=torch.float32)
        episode_reward = 0
        while True:
            # epsilon 0: always take the greedy action
            action = get_action(obs, 0)
            obs, reward, done, trunc, info = env.step(action)
            obs = torch.as_tensor(obs, dtype=torch.float32)
            episode_reward += reward
            if render:
                env.render()
            if done or trunc:
                break
        return episode_reward >= env.spec.reward_threshold, episode_reward
    finally:
        # this function is called repeatedly during training; release the
        # environment it created instead of leaving it open
        env.close()

In [None]:
# sanity checks on test(): index [0] is the pass flag, index [1] the episode reward
assert not test()[0], "untrained model should probably not be able to solve the environment"
# two runs with the default seed must give the same reward
assert np.array_equal(test()[1], test()[1]), "tests should be performed with the same seed"

In [None]:
# step metrics
loss_history: List[float] = []  # policy loss of every optimisation step
learning_rate_history: List[float] = []  # optimizer learning rate at every optimisation step
reward_history: List[float] = []  # reward of every environment step
duration_history: List[int] = []  # value of the step counter when each episode ended
action_history: List[int] = []  # action taken at every environment step
terminal_history: List[bool] = []  # `done` flag of every environment step
# other metrics
test_reward_history: List[Tuple[int, float, float]] = []  # (steps_done, mean test reward, spread)
last_episode_reward = 0  # total reward of the most recently finished training episode

In [None]:
# optimizer for the online policy network
optimizer = torch.optim.Adam(policy.parameters(), lr=0.001)
# optimizer = torch.optim.RMSprop(policy.parameters(), lr=LEARNING_RATE)
# the training loop steps this scheduler with the policy loss after every optimisation step
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    # mode="max",
    mode="min",
    factor=0.9,
    patience=5000,
    cooldown=5000,
    min_lr=0.00001,
    verbose=True,
)
# global count of environment steps; drives epsilon decay, target updates and testing
steps_done = 0

In [None]:
# %load_ext scalene
# use %%scalene to profile a cell


In [None]:
# train loop
TRAIN_EPISODES = 200000
# TRAIN_EPISODES = 100000

# Hyper-parameters. None of them changes while training, so they are set once
# here instead of being re-assigned on every environment step.
PREVIEW_EPISODE_INTERVAL = 30  # render every N-th episode
BATCH_SIZE = 128
REWARD_GAMMA = 0.99  # discount applied to the bootstrapped next-state value
SOFT_UPDATE_TAU = 0.001  # only needed by the commented-out soft_update alternative
POLICY_TARGET_UPDATE_INTERVAL = 200  # steps between hard target-network syncs
TEST_INTERVAL = 500  # steps between greedy evaluation rounds

solved = False
with tqdm(total=TRAIN_EPISODES, dynamic_ncols=True, ascii=True) as pbar:
    for episode in range(TRAIN_EPISODES):
        if solved: break
        state = torch.as_tensor(env.reset(seed=random.randint(0,100000))[0], dtype=torch.float32)
        episode_reward = 0
        for episode_step in count():
            epsilon = get_exploration_epsilon(steps_done)
            #region exploration
            action = get_action(state, epsilon)
            action_history.append(action)
            next_state, reward, done, trunc, info = env.step(action)
            episode_reward += reward
            reward_history.append(reward)
            terminal_history.append(done)
            if done or trunc:
                duration_history.append(episode_step)
                # the stored next state of a finished episode is a zero placeholder
                next_state = torch.zeros(state.shape)
            else:
                next_state = torch.as_tensor(next_state, dtype=torch.float32)
            # NOTE(review): truncation is stored as `finished`, so truncated
            # steps are not bootstrapped from their next state - confirm this
            # is intended.
            memory.append(Transition(
                observation=state,
                action=action,
                reward=reward,
                next_state=next_state,
                finished=done or trunc,
            ))
            state = next_state

            if episode % PREVIEW_EPISODE_INTERVAL == 0:
                env.render()
            #endregion exploration

            #region training
            if len(memory) < BATCH_SIZE+4:
                pbar.set_description("warmup")
            else:
                policy.train()
                batch = memory.sample(BATCH_SIZE)

                # calculate q values for the actions that were taken
                q_pred = policy(batch.observations).gather(1, batch.actions.unsqueeze(1))

                # calculate q values for next state (stays 0 for finished transitions)
                q_next = torch.zeros(BATCH_SIZE)
                non_final = ~batch.finished
                q_next[non_final] = policy_target(batch.next_observations[non_final]).max(dim=1).values.detach()

                # calculate expected q values
                q_expected = ((q_next * REWARD_GAMMA) + batch.rewards).unsqueeze(1)

                # calculate loss
                # criterion = torch.nn.SmoothL1Loss()
                # criterion = torch.nn.MSELoss()
                if isinstance(memory, PrioritizedReplayBuffer):
                    td_error = (q_expected - q_pred).abs().detach().flatten().tolist()
                    indices = batch.indices.tolist()
                    # update replay buffer priorities
                    memory.update(indices, td_error)
                    # Scale each sample's loss by its own importance-sampling
                    # weight before averaging. Multiplying the already
                    # averaged loss by the weights would only rescale it by
                    # their mean.
                    per_sample_loss = F.mse_loss(q_pred, q_expected, reduction="none").squeeze(1)
                    policy_loss = (batch.weights * per_sample_loss).mean()
                else:
                    policy_loss = F.mse_loss(q_pred, q_expected)

                loss_history.append(policy_loss.item())

                # apply weight update
                optimizer.zero_grad()
                policy_loss.backward()
                for param in policy.parameters():
                    assert param.grad is not None
                    # clip every gradient element to [-1, 1]
                    param.grad.data.clamp_(-1, 1)
                optimizer.step()

                # update learning rate
                lr = optimizer.param_groups[0]["lr"]
                learning_rate_history.append(lr)
                scheduler.step(policy_loss)

                ## update target network
                # soft_update(policy_target, policy, SOFT_UPDATE_TAU)
                if steps_done % POLICY_TARGET_UPDATE_INTERVAL == 0:
                    hard_update(policy_target, policy)

                #region testing
                if steps_done % TEST_INTERVAL == 0:
                    tests = list(zip(*[test() for _ in range(5)]))
                    test_passes: List[bool] = tests[0] # type: ignore :P
                    test_rewards: List[float] = tests[1] # type: ignore :P
                    passed = all(test_passes)
                    test_reward_mean = np.mean(test_rewards).item()
                    # largest distance of any test reward from the mean
                    test_reward_variance = max(max(test_rewards) - test_reward_mean, test_reward_mean - min(test_rewards))

                    test_reward_history.append((steps_done, test_reward_mean, test_reward_variance))

                    # best round so far, judged by its lower bound (mean - spread)
                    best = max(test_reward_history, key=lambda x: x[1] - x[2])

                    test_reward = f"{test_reward_mean:.3f} \u00b1 {test_reward_variance:.3f}"
                    best_reward = f"{best[1]:.3f} \u00b1 {best[2]:.3f}"
                    print(f"test reward: {test_reward} (best: {best_reward}, goal: {env.spec.reward_threshold})")
                    last_test_reward = test_reward
                    if passed:
                        print("solved!")
                        solved=True
                        # leaves the step loop; the episode loop stops on `solved`
                        break
                #endregion testing

                pbar.set_description(f"policy: {policy_loss.item():09.3f}, reward: {reward:+07.3f} (last episode: {last_episode_reward:.3f}), epsilon: {epsilon:.3f}, lr: {lr:.7f}, steps: {steps_done}")
            #endregion training

            steps_done += 1

            if done or trunc: break

        last_episode_reward = episode_reward
        pbar.update()

In [None]:
# draw a small batch to inspect the sampling metadata
batch = memory.sample(10)
# print(batch.actions)
# print(batch.rewards)
# print(batch.indices)
# both fields are None when the deque buffer is selected
print(batch.priorities)
print(batch.weights)

In [None]:
# plot rewards
plt.figure()
# each entry is (steps_done, mean, spread); the curve shows mean - spread per test round
plt.plot([x[1]-x[2] for x in test_reward_history])
plt.title("Test rewards")
plt.show()

In [None]:
# plot durations
plt.figure()
# one point per finished training episode
plt.plot(duration_history)
plt.title("Episode durations")
plt.show()

In [None]:
# plot loss: full history, the last 300 steps, and a histogram of log10(loss)
fig, axs = plt.subplots(3)
plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, hspace=0.4)
fig.suptitle("Loss")
history_ax, recent_ax, hist_ax = axs
history_ax.plot(loss_history)
# keep the original step numbers on the x axis of the zoomed-in view
recent = slice(-300, None)
recent_ax.plot(range(len(loss_history))[recent], loss_history[recent])
hist_ax.hist(np.log10(loss_history))
plt.show()

In [None]:
# investigate actions
import ipywidgets as widgets
plt.figure()
# @widgets.interact(i=(0, len(action_history)-1), window_size=(1, len(action_history)-1))
@widgets.interact(i=(0, len(action_history)-1), window_size=(1, 1000))
def preview_actions(i=max(0, len(action_history)-100), window_size=72):
    """Show a window of the action history as an image.

    Each row is one environment step. The first ``env.action_space.n``
    columns one-hot encode the action taken, and the extra last column marks
    steps whose ``done`` flag was set.

    Args:
        i: index of the first step shown.
        window_size: number of steps shown.
    """
    num_actions = env.action_space.n
    x = torch.zeros((len(action_history), num_actions + 1))
    x[range(len(action_history)), action_history] = 1
    # The terminal flag goes in the extra last column. A fixed column index
    # would overwrite an action column when there are more than two actions.
    x[:, -1] = torch.as_tensor(terminal_history, dtype=torch.bool)
    plt.imshow(x[i:i+window_size], aspect="auto")
    plt.show()
    del x

In [None]:
# preview: watch five rendered greedy episodes
# (test() always calls get_action with epsilon 0, independent of this variable)
epsilon = 0
for episode in range(5):
    success, reward = test(render=True)
    print(f"success={success}, reward={reward}")

- Security metrics
    - Investigate the paper "Deep Reinforcement Learning-Based Defense Strategy Selection"
- Why DQN over other algorithms? Does it reflect our real-world scenario?
- Plan: results in early to mid January