In [196]:
from typing import List
import math
import random
from collections import namedtuple, deque
from itertools import count

import gymnasium as gym

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np

In [197]:
# CartPole-v1 environment: discrete actions, vector observations
env = gym.make(id='CartPole-v1')

In [198]:
# All tensors and networks in this notebook live on the CPU
device = torch.device("cpu")

In [199]:
# One step of experience stored in the replay buffer
Transition = namedtuple('Transition', 'state action next_state reward')

In [200]:
class DQN(nn.Module):

    '''
    Three-layer fully connected Q-network (Linear -> ReLU -> Linear -> ReLU -> Linear),
    instantiated once as the policy network and once as the target network.

    Parameters
    ==========
    dim_state: Number of input features per state (size of the first layer's input)
    n_actions: Number of discrete actions; the network outputs one Q-value per action
    '''

    def __init__(self, dim_state: int, n_actions: int) -> None:
        super().__init__()
        hidden_units = 128
        # The attribute names layer1..layer3 become the state_dict keys, which the
        # target network relies on when copying parameters; keep them stable.
        self.layer1 = nn.Linear(dim_state, hidden_units)
        self.layer2 = nn.Linear(hidden_units, hidden_units)
        self.layer3 = nn.Linear(hidden_units, n_actions)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Map a batch of states to one (unactivated) Q-value per action.'''
        hidden = torch.relu(self.layer1(x))
        hidden = torch.relu(self.layer2(hidden))
        return self.layer3(hidden)

In [201]:
# --- DQN training hyperparameters ---
# NOTE(review): BATCH_SIZE is not read by the visible cells; replay sampling
# uses the replay buffer's own 'batch_size' hyperparameter instead.
BATCH_SIZE = 128
# Discount factor applied to the bootstrapped next-state value
GAMMA = 0.99
# Initial probability of taking a uniformly random (exploratory) action
EPS_START = 0.9
# Asymptotic (final) probability of taking a uniformly random action
EPS_END = 0.05
# Time constant, in steps, of the exponential decay from EPS_START towards EPS_END
EPS_DECAY = 1000
# Soft-update rate: target <- TAU * policy + (1 - TAU) * target
TAU = 0.005
# Learning rate of the optimizer
LR = 1e-4

In [202]:
# Number of actions possible from each state
n_actions = env.action_space.n
# Number of dimensions (representation) for each state
state, info = env.reset()
dim_state = len(state)

In [203]:
# The online (policy) network is built first, then the target network that
# provides its bootstrapped targets; both share the same architecture.
policy_net, target_net = (DQN(dim_state, n_actions).to(device) for _ in range(2))
# Start the target network as an exact copy of the policy network
target_net.load_state_dict(policy_net.state_dict())

<All keys matched successfully>

In [204]:
# AdamW (AMSGrad variant) over the policy network's parameters only
optimizer = optim.AdamW(params=policy_net.parameters(), amsgrad=True, lr=LR)

In [205]:
# Global step counter read and incremented by select_action (drives epsilon decay)
steps_done: int = 0

In [206]:
def select_action(state: torch.Tensor) -> torch.Tensor:
    '''
    Choose an action for ``state`` with an epsilon-greedy policy.

    The exploration probability is
        eps = EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY),
    evaluated with the counter value *before* this call. With probability eps
    a uniformly random action is drawn from ``env.action_space``; otherwise the
    action with the largest Q-value under ``policy_net`` is chosen. Every call
    increments the global ``steps_done``.

    Returns a long tensor of shape (1, 1) holding the chosen action index.
    '''
    global steps_done
    draw = random.random()
    decay = math.exp(-1 * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= eps_threshold:
        # Explore: uniformly random action
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

    with torch.no_grad():
        # Exploit: .max(1) yields a (values, indices) pair along dim 1;
        # .indices is the arg-max action for each row of the batch.
        return policy_net(state).max(1).indices.view(1, 1)

In [208]:
class Node(object):
    '''Key/value pair used as the storage unit of the prioritized replay buffer.

    ``key`` is the ordering key (the experience's TD-error-based priority) and
    ``value`` is the stored experience tuple. Both are mutable in place so that
    the same Node object can be shared by several containers.
    '''
    def __init__(self, key, value):
        # Priority used for ordering (derived from the TD error)
        self.key = key
        # The experience tuple
        self.value = value

    def update_key_and_value(self, new_key, new_value):
        '''Replace both the key and the value.'''
        self.update_key(new_key)
        self.update_value(new_value)

    def update_key(self, new_key):
        '''Assign a new priority (TD-error-based key) to the experience.'''
        self.key = new_key

    def update_value(self, new_value):
        '''Overwrite the stored experience with a new one.'''
        self.value = new_value

    def __eq__(self, other):
        '''Two nodes are equal if and only if their keys and values are equal.

        Comparing with a non-Node returns NotImplemented (so ``==`` evaluates to
        False) instead of raising AttributeError.
        '''
        if not isinstance(other, Node):
            return NotImplemented
        return self.key == other.key and self.value == other.value

    # Defining __eq__ already makes instances unhashable; state it explicitly,
    # since key/value are mutated in place and a hash would be unstable.
    __hash__ = None

In [209]:
class MaxHeap(object):
    '''Array-backed binary max-heap of Node objects ordered by ``key``.

    Slot 0 holds a sentinel whose key is +inf, so sifting up always stops
    below it. Real entries live at indices 1, 2, ... and the children of slot
    ``i`` are ``2 * i`` and ``2 * i + 1``. The backing array has
    ``max_size * 4 + 1`` slots so child lookups stay in range; unused slots
    hold Nodes with ``default_key_to_use`` as key.
    '''

    def __init__(self, max_size, dimension_of_value_attribute, default_key_to_use):
        # Capacity (number of real entries the heap is sized for)
        self.max_size = max_size
        # Length of the all-None placeholder tuple stored in empty Nodes
        self.dimension_of_value_attribute = dimension_of_value_attribute
        # Key given to the placeholder Nodes
        self.default_key_to_use = default_key_to_use
        self.heap = self.initialize_heap()

    def initialize_heap(self):
        '''Build the (max_size * 4 + 1)-slot heap: +inf sentinel at 0, placeholders elsewhere.'''
        placeholder_count = self.max_size * 4
        slots = [Node(float('inf'), tuple([None] * self.dimension_of_value_attribute))]
        slots.extend(
            Node(self.default_key_to_use, tuple([None] * self.dimension_of_value_attribute))
            for _ in range(placeholder_count)
        )
        return np.array(slots)

    def update_heap_element(self, heap_index, new_element):
        '''Place ``new_element`` at ``heap_index`` without restoring heap order.'''
        self.heap[heap_index] = new_element

    def swap_heap_elements(self, index1, index2):
        '''Exchange the nodes stored at two heap positions.'''
        self.heap[index1], self.heap[index2] = self.heap[index2], self.heap[index1]

    def calculate_index_of_biggest_child(self, heap_index_changed):
        '''Return the index of the child of ``heap_index_changed`` with the larger key.

        Children are at ``2 * i`` (left) and ``2 * i + 1`` (right); ties go to the right child.
        '''
        left = 2 * heap_index_changed
        right = left + 1
        return left if self.heap[left].key > self.heap[right].key else right

    def give_max_key(self):
        '''Return the largest key in the heap (the key of the node at slot 1).'''
        return self.heap[1].key

    def reorganize_heap(self, heap_index_changed):
        '''Restore heap order after the key at ``heap_index_changed`` changed.

        The node is first sifted up while its key exceeds its parent's; if it
        does not move up, it is sifted down while its key is smaller than its
        larger child's. Any slot (including 1) may therefore move.
        '''
        position = heap_index_changed
        while True:
            node_key = self.heap[position].key
            parent = position // 2
            if node_key > self.heap[parent].key:
                self.swap_heap_elements(position, parent)
                position = parent
                continue
            child = self.calculate_index_of_biggest_child(position)
            if node_key < self.heap[child].key:
                self.swap_heap_elements(position, child)
                position = child
                continue
            return

    def update_element_and_reorganize_heap(self, heap_index_changed, new_element):
        '''Store ``new_element`` at ``heap_index_changed`` and restore heap order.'''
        self.update_heap_element(heap_index_changed, new_element)
        self.reorganize_heap(heap_index_changed)

In [210]:
class Deque(object):
    '''Fixed-capacity circular buffer of Node objects.

    New elements overwrite the slot at ``deque_index_to_overwrite_next``. After
    the last slot (``max_size - 1``) has been written, the index wraps to 0 and
    ``reached_max_capacity`` becomes True; from then on the oldest entries are
    overwritten in insertion order.
    '''

    def __init__(self, max_size, dimension_of_value_attribute, default_key_to_use):
        # Maximum number of elements held
        self.max_size = max_size
        # Length of the all-None placeholder tuple stored in empty Nodes
        self.dimension_of_value_attribute = dimension_of_value_attribute
        # Key of the placeholder Nodes. Must be set BEFORE initialize_deque(),
        # which reads it (previously it was assigned afterwards, so a standalone
        # Deque raised AttributeError).
        self.default_key_to_use = default_key_to_use
        self.deque = self.initialize_deque()
        # Slot that the next added element will overwrite
        self.deque_index_to_overwrite_next = 0
        # True once the index has wrapped around at least once
        self.reached_max_capacity = False
        # Number of real elements stored (saturates at max_size)
        self.number_experiences_in_deque = 0

    def initialize_deque(self):
        '''Create ``max_size`` placeholder Nodes (key = default_key_to_use, value = all-None tuple).'''
        deque = np.array([Node(self.default_key_to_use, tuple([None for _ in range(self.dimension_of_value_attribute)]))
                          for _ in range(self.max_size)])
        return deque

    def update_deque_node_key(self, index, new_key):
        '''Update the key of the node at ``index`` in place.'''
        self.deque[index].update_key(new_key)

    def update_deque_node_value(self, index, new_value):
        '''Update the value of the node at ``index`` in place.'''
        self.deque[index].update_value(new_value)

    def update_deque_node_key_and_value(self, index, new_key, new_value):
        '''Update both the key and the value of the node at ``index`` in place.'''
        self.update_deque_node_key(index, new_key)
        self.update_deque_node_value(index, new_value)

    def update_deque_index_to_overwrite_next(self):
        '''Advance the write index by one, wrapping to 0 after the last slot.

        Wrapping also marks the deque as full (``reached_max_capacity = True``).
        '''
        # The index runs from 0 to self.max_size - 1
        if self.deque_index_to_overwrite_next < self.max_size - 1:
            self.deque_index_to_overwrite_next += 1
        else:
            self.reached_max_capacity = True
            self.deque_index_to_overwrite_next = 0

    def update_number_experiences_in_deque(self):
        '''Increment the element count unless the deque is already full.

        add_element_to_deque calls this, but it was previously defined only on a
        subclass, so a plain Deque raised AttributeError.
        '''
        if not self.reached_max_capacity:
            self.number_experiences_in_deque += 1

    def add_element_to_deque(self, new_key, new_value):
        '''Write (new_key, new_value) into the next slot and advance the bookkeeping.'''
        self.update_deque_node_key_and_value(self.deque_index_to_overwrite_next, new_key, new_value)
        self.update_number_experiences_in_deque()
        self.update_deque_index_to_overwrite_next()

In [225]:
class PrioritizedReplayBuffer(MaxHeap, Deque):
    '''Proportional prioritized experience replay (Schaul et al., 2016).

    Each stored experience is a single Node object referenced from two places:
      - the circular deque (insertion order; the oldest slot is overwritten
        once the buffer is full), and
      - the max-heap (ordered by priority, so the current maximum priority is
        available at heap slot 1).
    Every such Node gets a ``heap_index`` attribute, kept in sync with its heap
    position by swap_heap_elements.

    Parameters
    ==========
    hyperparameters: dict with keys 'buffer_size', 'alpha_prioritized_replay',
        'beta_prioritized_replay', 'incremental_td_error' and 'batch_size'
    seed: Seed for NumPy's global RNG (used by sample())
    '''

    def __init__(self, hyperparameters, seed=0):

        MaxHeap.__init__(self, hyperparameters['buffer_size'], dimension_of_value_attribute=4, default_key_to_use=0)
        Deque.__init__(self, hyperparameters['buffer_size'], dimension_of_value_attribute=4, default_key_to_use=0)

        # Seeds NumPy's *global* RNG, which np.random.choice in sample() draws from
        np.random.seed(seed)

        # Priority of every deque slot; empty slots stay at 0 so they are never sampled
        self.deques_td_errors = self.initialize_td_errors_array()
        # Remember: index 0 in the heap (the +inf sentinel) cannot be swapped or modified
        self.heap_index_to_overwrite_next = 1
        self.number_experiences_in_deque = 0
        # Running sum of all stored priorities, maintained incrementally
        self.adapted_overall_sum_of_td_errors = 0

        # Priority exponent alpha (0 -> uniform sampling, 1 -> fully proportional)
        self.alpha = hyperparameters['alpha_prioritized_replay']
        # Importance-sampling exponent beta (how strongly the sampling bias is corrected)
        self.beta = hyperparameters['beta_prioritized_replay']
        # Epsilon added to |TD error| in p_i = (|delta_i| + eps) ** alpha
        self.incremental_td_error = hyperparameters['incremental_td_error']
        # Number of experiences drawn per call to sample()
        self.batch_size = hyperparameters['batch_size']

        # Deque indexes drawn by the latest sample(); consumed by update_td_errors().
        # (Previously only a differently named attribute was initialized, so
        # calling update_td_errors() before sample() raised AttributeError.)
        self.deque_sample_indexes_to_update_td_error_for = None
        # Legacy attribute kept for backward compatibility; not used by this class
        self.heap_indexes_to_update_td_error_for = None

        # Device on which the importance-sampling weights are returned
        self.device = 'cpu'

    def initialize_td_errors_array(self):
        '''One priority per deque slot, all initially 0.'''
        return np.zeros(self.max_size)

    def add_experience(self, td_error, state, action, next_state, reward):
        '''Store a transition with priority ``td_error``, overwriting the oldest when full.'''
        # Replace the priority of the slot being overwritten by the new one in the running sum
        self.update_overall_sum(td_error, self.deque[self.deque_index_to_overwrite_next].key)
        self.update_deque_and_deque_td_errors(td_error, state, action, next_state, reward)
        self.update_heap_and_heap_index_to_overwrite()
        self.update_number_experiences_in_deque()
        self.update_deque_index_to_overwrite_next()

    def update_overall_sum(self, new_td_error, old_td_error):
        '''Update the running sum of priorities when one priority changes.'''
        self.adapted_overall_sum_of_td_errors += new_td_error - old_td_error

    def update_deque_and_deque_td_errors(self, td_error, state, action, next_state, reward):
        '''Write the priority array entry and the deque node for the current slot.'''
        self.deques_td_errors[self.deque_index_to_overwrite_next] = td_error
        self.add_element_to_deque(td_error, Transition(state, action, next_state, reward))

    def add_element_to_deque(self, new_key, new_value):
        '''Overrides Deque.add_element_to_deque: only writes the slot in place.

        add_experience advances the count and the write index itself.
        '''
        self.update_deque_node_key_and_value(self.deque_index_to_overwrite_next, new_key, new_value)

    def update_heap_and_heap_index_to_overwrite(self):
        '''Insert (or re-position) the node just written to the deque in the heap.'''
        if not self.reached_max_capacity:
            # While filling up, the deque node is placed at the next free heap slot
            # (heap slots start at 1, deque slots at 0)
            self.update_heap_element(self.heap_index_to_overwrite_next, self.deque[self.deque_index_to_overwrite_next])
            # Remember where this node lives in the heap
            self.deque[self.deque_index_to_overwrite_next].heap_index = self.heap_index_to_overwrite_next
            self.update_heap_index_to_overwrite_next()
        # Once full, the deque node was modified in place, which also modifies the
        # same object in the heap; its stored heap_index tells where to re-sort.
        heap_index_change = self.deque[self.deque_index_to_overwrite_next].heap_index
        self.reorganize_heap(heap_index_change)

    def update_number_experiences_in_deque(self):
        '''Increments number_experiences_in_deque by one until the buffer is full.'''
        if not self.reached_max_capacity:
            self.number_experiences_in_deque += 1

    def update_heap_index_to_overwrite_next(self):
        '''Advance the next free heap slot (1 -> 2 -> 3 ...).'''
        self.heap_index_to_overwrite_next += 1

    def swap_heap_elements(self, index1, index2):
        '''Swap two heap nodes and keep their ``heap_index`` attributes in sync.'''
        self.heap[index1], self.heap[index2] = self.heap[index2], self.heap[index1]
        # The nodes are shared with the deque, so the deque sees the new indexes too
        self.heap[index1].heap_index = index1
        self.heap[index2].heap_index = index2

    def sample(self):
        '''Draw a batch of transitions and their normalized importance-sampling weights.

        The sampled deque indexes are remembered for the next update_td_errors() call.
        '''
        experiences, deque_sample_indexes = self.pick_experiences_based_on_proportional_td_error()
        transitions = self.separate_out_data_types(experiences)
        self.deque_sample_indexes_to_update_td_error_for = deque_sample_indexes
        importance_sampling_weights = self.calculate_importance_sampling_weights(experiences)
        return transitions, importance_sampling_weights

    def pick_experiences_based_on_proportional_td_error(self):
        '''Sample batch_size distinct deque slots with P(i) = p_i / sum_j p_j.

        Raises ValueError when there are no positive priorities (or, from NumPy,
        when fewer than batch_size slots have a non-zero priority).
        '''
        priorities = self.deques_td_errors
        # Normalize by the exact sum rather than the incrementally maintained running
        # sum: the running sum accumulates floating-point drift over many updates and
        # np.random.choice rejects probability vectors that do not sum to 1.
        total = priorities.sum()
        if total <= 0:
            raise ValueError('Cannot sample: the replay buffer has no positive priorities')
        probabilities = priorities / total
        deque_sample_indexes = np.random.choice(len(priorities), size=self.batch_size, replace=False, p=probabilities)
        experiences = self.deque[deque_sample_indexes]
        # Experiences sampled and their indexes in the deque
        return experiences, deque_sample_indexes

    def separate_out_data_types(self, experiences):
        '''Extract the Transition tuples from the sampled Nodes.'''
        transitions = [e.value for e in experiences]
        return transitions

    def calculate_importance_sampling_weights(self, experiences):
        '''Return w_i = (N * P(i)) ** -beta for the batch, divided by the batch maximum.'''
        td_errors = [experience.key for experience in experiences]
        importance_sampling_weights = [((1 / self.number_experiences_in_deque) * (self.give_adapted_sum_of_td_errors() / td_error)) ** self.beta for td_error in td_errors]
        sample_max_importance_weight = max(importance_sampling_weights)
        importance_sampling_weights = [is_weight / sample_max_importance_weight for is_weight in importance_sampling_weights]
        # A 1-d tensor that will be multiplied with the 1-d per-sample loss
        importance_sampling_weights = torch.tensor(importance_sampling_weights).float().to(self.device)
        return importance_sampling_weights

    def update_td_errors(self, td_errors):
        '''Refresh the priorities of the experiences returned by the latest sample().

        ``td_errors`` must be in the same order as that sample; each priority
        becomes (|td_error| + incremental_td_error) ** alpha.

        Raises RuntimeError if sample() has not been called yet.
        '''
        if self.deque_sample_indexes_to_update_td_error_for is None:
            raise RuntimeError('sample() must be called before update_td_errors()')
        for raw_td_error, deque_index in zip(td_errors, self.deque_sample_indexes_to_update_td_error_for):
            td_error = (abs(raw_td_error) + self.incremental_td_error) ** self.alpha
            # Look the heap position up afresh: earlier iterations may have moved the node
            corresponding_heap_index = self.deque[deque_index].heap_index
            self.update_overall_sum(td_error, self.heap[corresponding_heap_index].key)
            self.heap[corresponding_heap_index].key = td_error
            self.reorganize_heap(corresponding_heap_index)
            self.deques_td_errors[deque_index] = td_error

    def give_max_td_error(self):
        '''Largest priority currently stored.'''
        return self.give_max_key()

    def give_adapted_sum_of_td_errors(self):
        '''Running sum of all stored priorities.'''
        return self.adapted_overall_sum_of_td_errors

    def __len__(self):
        return self.number_experiences_in_deque

In [226]:
# Configuration of the prioritized replay buffer
hyperparameters = dict(
    batch_size=256,                   # experiences sampled per optimization step
    buffer_size=40000,                # capacity of the buffer
    alpha_prioritized_replay=0.6,     # priority exponent alpha
    beta_prioritized_replay=0.1,      # importance-sampling exponent beta
    incremental_td_error=1e-8,        # epsilon added to |TD error|
)

# Prioritized replay memory (NumPy's global RNG is seeded with 0)
prm = PrioritizedReplayBuffer(hyperparameters, seed=0)

In [227]:
def optimize_model():
    '''
    One optimization step of Double DQN with proportional prioritized replay
    (Schaul et al., 2016).

    Does nothing until the replay buffer holds at least ``prm.batch_size``
    experiences. Otherwise it:
      1. samples a batch and its importance-sampling weights from ``prm``,
      2. builds Double-DQN targets (action chosen by policy_net, evaluated by target_net),
      3. writes the new TD errors back to ``prm`` as priorities,
      4. minimizes the importance-sampling-weighted Huber loss of policy_net.
    '''
    if len(prm) < prm.batch_size:
        return

    transitions, importance_sampling_weights = prm.sample()
    batch = Transition(*zip(*transitions))

    # True where the transition has a successor state (next_state is None for terminal steps)
    non_final_mask = torch.tensor(tuple(s is not None for s in batch.next_state),
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s, a; theta) for the actions actually taken, shape (B, 1)
    Q_expected = policy_net(state_batch).gather(1, action_batch)

    # Bootstrapped next-state values; stay 0 for terminal transitions
    next_state_values = torch.zeros(prm.batch_size, device=device)
    # torch.cat([]) raises, so only evaluate successors when at least one exists
    if non_final_mask.any():
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        with torch.no_grad():
            # argmax_a Q(next_state, a; theta) selected by the policy network ...
            next_state_actions = policy_net(non_final_next_states).max(1).indices.view(-1, 1)
            # ... evaluated by the target network: Q(next_state, argmax; theta')
            next_state_values[non_final_mask] = target_net(non_final_next_states) \
                .gather(1, next_state_actions).squeeze(1)

    # Q-targets, shape (B,)
    Q_targets = reward_batch + GAMMA * next_state_values

    # New TD errors for the sampled transitions become their new priorities
    td_errors = Q_targets.detach().cpu().numpy() - Q_expected.squeeze(1).detach().cpu().numpy()
    prm.update_td_errors(td_errors)

    # Per-sample Huber loss (reduction='none') so each term can be scaled by its
    # importance-sampling weight. The default 'mean' reduction collapsed the loss to a
    # scalar first, which made the weights only rescale it by their mean.
    elementwise_loss = F.smooth_l1_loss(Q_expected.squeeze(1), Q_targets, reduction='none')
    loss = torch.mean(elementwise_loss * importance_sampling_weights.to(elementwise_loss.device))

    optimizer.zero_grad()
    loss.backward()
    # Apply the gradient clipping
    nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    # Update the parameters of the policy network
    optimizer.step()

In [228]:
num_episodes = 600

for episode in range(num_episodes):
    # Fresh episode; states are kept as (1, dim_state) float32 row tensors
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    done = False
    while not done:
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        # Terminal transitions are stored with next_state=None (no bootstrapping);
        # truncated ones keep their real next observation.
        next_state = None if terminated else torch.tensor(
            observation, dtype=torch.float32, device=device).unsqueeze(0)

        # New experiences get the floor priority until the buffer can fill a batch,
        # and the current maximum priority afterwards.
        if len(prm) >= prm.batch_size:
            priority = prm.give_max_td_error()
        else:
            priority = prm.incremental_td_error ** prm.alpha
        prm.add_experience(priority, state, action, next_state, reward)

        state = next_state

        optimize_model()

        # Soft (Polyak) update: target <- TAU * policy + (1 - TAU) * target
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        target_net_state_dict.update(
            (name, policy_net_state_dict[name] * TAU + target_net_state_dict[name] * (1 - TAU))
            for name in policy_net_state_dict
        )
        target_net.load_state_dict(target_net_state_dict)