In [None]:
import numpy as np
import matplotlib.pyplot as plt
import gym
import tensorflow as tf
import tensorflow.keras as keras # type: ignore
from tensorflow.keras import layers # type: ignore
from tqdm import tqdm
from collections import deque
import random

In [None]:
# Fix every RNG the experiment draws from so that Restart & Run All is repeatable.
SEED = 42
tf.keras.utils.set_random_seed(SEED)  # seeds Python `random`, NumPy and TensorFlow

env = gym.make("CartPole-v1")
env.reset(seed=SEED)         # seeds the environment's RNG; later unseeded resets continue from it
env.action_space.seed(SEED)  # seeds env.action_space.sample(), used for epsilon-greedy exploration
num_actions = env.action_space.n
num_states = env.observation_space.shape[0]

print('num_actions:', num_actions)
print('num_states:', num_states)

In [None]:
# def create_q_model():
#     inputs = layers.Input(shape=(num_states,))
#     x = layers.Dense(24, activation="relu")(inputs)
#     x = layers.Dense(24, activation="relu")(x)
#     action = layers.Dense(num_actions, activation="linear")(x)
#     return keras.Model(inputs=inputs, outputs=action)

def create_q_model(num_states, num_actions):
    """Build a Q-network mapping a state vector to one Q-value per action.

    Architecture:
    - Three hidden blocks of 128, 128 and 64 units. Each block is a linear
      Dense followed by LeakyReLU(alpha=0.01).
    - Dropout(0.2) after the second block.
    - A linear Dense head with `num_actions` outputs.

    Args:
        num_states: length of the observation vector.
        num_actions: number of discrete actions (width of the output layer).

    Returns:
        An uncompiled functional ``keras.Model``.
    """
    # (units, whether Dropout follows this block)
    hidden_blocks = ((128, False), (128, True), (64, False))

    state_input = layers.Input(shape=(num_states,))
    features = state_input
    for units, with_dropout in hidden_blocks:
        features = layers.Dense(units, activation="linear")(features)
        features = layers.LeakyReLU(alpha=0.01)(features)
        if with_dropout:
            features = layers.Dropout(0.2)(features)

    q_values = layers.Dense(num_actions, activation="linear")(features)
    return keras.Model(inputs=state_input, outputs=q_values)

In [None]:
def reduce_nodes(model, reduction_factor):
    """Return a copy of `model` whose hidden Dense layers are shrunk by `reduction_factor`.

    Which units are kept:
    - Every Dense layer except the last one keeps the first
      ``max(int(units * reduction_factor), 1)`` of its units.
    - The last Dense layer (the Q-value output) keeps all of its units.

    Kept units inherit their trained weights. Each kernel of shape
    (inputs, units) is sliced to its first rows and first columns; the rows
    match the units kept in the previous Dense layer.

    Non-Dense layers are re-added to the new ``keras.Sequential`` as the same
    layer objects, so they are shared with `model`.

    Args:
        model: Keras model whose ``layers`` form a plain chain (assumed: no
            branches, since the result is rebuilt as a Sequential).
        reduction_factor: fraction of units to keep, in (0, 1].

    Returns:
        A new ``keras.Sequential`` model.

    Raises:
        ValueError: if `reduction_factor` is not in (0, 1]. A factor above 1
            would ask for more units than there are trained weights to copy.
    """
    if not 0 < reduction_factor <= 1:
        raise ValueError(f"reduction_factor must be in (0, 1], got {reduction_factor!r}")

    # Protect the last *Dense* layer rather than the last layer overall, so a
    # trailing non-Dense layer (e.g. an activation) cannot get the output layer pruned.
    dense_positions = [i for i, layer in enumerate(model.layers) if isinstance(layer, layers.Dense)]
    output_position = dense_positions[-1] if dense_positions else None

    new_model = keras.Sequential()
    previous_units = None  # units kept in the most recent Dense layer

    for i, layer in enumerate(model.layers):
        if not isinstance(layer, layers.Dense):
            new_model.add(layer)  # shared with `model`, not copied
            continue

        config = layer.get_config()  # Refer to dictionary keys on https://www.tensorflow.org/guide/keras/serialization_and_saving
        current_units = config['units']
        if i == output_position:
            new_units = current_units  # output layer is never pruned
        else:
            new_units = max(int(current_units * reduction_factor), 1)  # a Dense layer needs at least 1 unit

        new_layer = layers.Dense(new_units, activation=config['activation'])

        params = layer.get_weights()
        if params:
            weights, biases = params  # kernel (inputs, units) and bias (units,)
            # Input width: the previous Dense layer's kept units, or the full
            # kernel height when this is the first Dense layer.
            input_dim = previous_units if previous_units is not None else weights.shape[0]
            new_layer.build((None, input_dim))  # build() takes the INPUT shape, not the output shape
            new_layer.set_weights([weights[:input_dim, :new_units], biases[:new_units]])

        previous_units = new_units
        new_model.add(new_layer)

    return new_model

def count_dense_nodes(model):
    """Return the number of units summed over every Dense layer in `model.layers`.

    Non-Dense layers contribute nothing. A model with no Dense layers yields 0.
    """
    return sum(layer.units for layer in model.layers if isinstance(layer, layers.Dense))

In [None]:
# Inspect the architecture of a freshly built Q-network.
# summary() prints the table itself and returns None, so it is not wrapped in print().
model_i = create_q_model(num_states, num_actions)
model_i.summary()

In [None]:
# DQP pruning hyperparameters (read by train_model)
pruning_threshold = 200   # prune only when the running median episode reward exceeds this
reduction_factor = 0.90   # hidden Dense layers keep this fraction of their units per prune
frequency = 20            # pruning is considered every `frequency` episodes

# Online Q-network, plus a target network that starts as an exact copy of it
model = create_q_model(num_states, num_actions)
target_model = create_q_model(num_states, num_actions)
target_model.set_weights(model.get_weights())

# TD-error loss; train_model builds its own Adam optimizer
loss_function = tf.keras.losses.MeanSquaredError()

def train_model(num_episodes=2000):
    """Train a DQN agent on the global `env`, pruning the Q-networks as it improves.

    Dynamic Q-Pruning (DQP) runs every `frequency` episodes. If the median
    reward of the last `frequency` episodes exceeds `pruning_threshold`:
    - every non-output Dense layer of the online and target networks is
      shrunk by `reduction_factor` (via `reduce_nodes`);
    - the target is re-synced to the pruned online network;
    - the optimizer is reset.

    Reads the notebook globals `env`, `loss_function`, `pruning_threshold`,
    `reduction_factor` and `frequency`. Rebinds the globals `model` and
    `target_model` when pruning.

    Args:
        num_episodes: number of training episodes to run (default 2000).

    Returns:
        A tuple ``(episode_rewards, average_rewards, episode_of_pruning, node_counts)``:
        - episode_rewards: total reward of each episode;
        - average_rewards: running mean over the last `frequency` episodes,
          including the current one;
        - episode_of_pruning: the episodes at which pruning happened;
        - node_counts: the Dense node count after each episode.
    """
    global model, target_model

    # learning hyperparameters
    epsilon = 1
    epsilon_min = 0.01
    epsilon_decay = 0.995
    gamma = 0.99
    batch_size = 32
    max_episode_reward = 1000  # hard cap in case the env never terminates or truncates
    target_sync_every = 10     # episodes between target-network syncs
    replay_buffer = deque(maxlen=10000)  # maintain replay buffer to randomly sample previous observations for each backprop
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    # running metrics
    episode_rewards = []
    average_rewards = []
    episode_of_pruning = []
    node_counts = []

    def update_model(minibatch, episode):
        """Take one gradient step on `model` from a replay minibatch; periodically sync the target."""
        # format minibatch history of states, actions, rewards, etc.
        states, actions, rewards, next_states, terminals = [np.array(list(x)) for x in zip(*minibatch)]
        states = np.squeeze(states)
        next_states = np.squeeze(next_states)

        # Bellman targets from the target network. Calling the model directly
        # avoids .predict()'s per-call overhead for a single small batch.
        next_q = target_model(next_states, training=False).numpy()
        target_q = rewards + (1 - terminals) * gamma * np.max(next_q, axis=1)
        target_q_full = model(states, training=False).numpy()
        target_q_full[np.arange(len(actions)), actions] = target_q

        # backpropagation
        # NOTE(review): no training=True here, so the Dropout layer is inactive during updates — confirm intended.
        with tf.GradientTape() as tape:
            predicted_q_values = model(states)
            loss = loss_function(target_q_full, predicted_q_values)
        grads = tape.gradient(loss, model.trainable_variables)
        # `optimizer` is read from the enclosing scope at call time, so a reset after pruning is picked up.
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if episode % target_sync_every == 0:
            target_model.set_weights(model.get_weights())

    for episode in tqdm(range(num_episodes)):
        # reset episode
        state, _ = env.reset()
        state = np.array(state, dtype=np.float32).reshape(1, -1)
        episode_reward = 0

        while True:
            # epsilon-greedy action
            if random.random() < epsilon:
                action = env.action_space.sample()
            else:
                state_tensor = tf.convert_to_tensor(state, dtype=tf.float32)
                action_probs = model(state_tensor, training=False)
                action = np.argmax(action_probs.numpy())

            next_state, reward, terminated, truncated, _ = env.step(action)
            next_state = np.array(next_state, dtype=np.float32).reshape(1, -1)

            # Store `terminated` only: a time-limit truncation is not a true
            # terminal state, so its successor must still be bootstrapped.
            replay_buffer.append((state, action, reward, next_state, terminated))

            state = next_state
            episode_reward += reward

            # End the episode on termination, truncation (the env then requires
            # a reset), or the reward cap. Then take one gradient step if the
            # buffer holds more than a minibatch.
            if terminated or truncated or episode_reward >= max_episode_reward:
                if len(replay_buffer) > batch_size:
                    update_model(random.sample(replay_buffer, batch_size), episode)
                break

        # Record this episode before computing the running statistics, so they
        # include it and are never taken over an empty list.
        episode_rewards.append(episode_reward)
        recent_rewards = episode_rewards[-frequency:]
        average_rewards.append(np.mean(recent_rewards))

        # model pruning - note this section is the only addition to q-learning novel to DQP.
        if episode > 0 and episode % frequency == 0:
            # prune IFF the running median reward exceeds the pruning threshold
            if np.median(recent_rewards) > pruning_threshold:
                print('PRUNING MODEL')
                episode_of_pruning.append(episode)

                model = reduce_nodes(model, reduction_factor)
                # Shrink the target to the same shapes, then copy the pruned
                # online weights into it. (clone_model would give the target
                # freshly initialized, untrained weights.)
                target_model = reduce_nodes(target_model, reduction_factor)
                target_model.set_weights(model.get_weights())
                # Fresh optimizer: the old one's slot variables have the pre-pruning shapes.
                optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

        node_counts.append(count_dense_nodes(model))
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

    return episode_rewards, average_rewards, episode_of_pruning, node_counts

In [None]:
# Run DQP training and collect the per-episode metrics plotted below.
(episode_rewards, average_rewards,
 episode_of_pruning, node_counts) = train_model()

In [None]:
fig, (ax_reward, ax_nodes) = plt.subplots(2, 1, figsize=(12, 6))

# Running-average reward, with the pruning threshold and each pruning episode marked
ax_reward.plot(average_rewards, label=f'Average Episode Reward (Last {frequency} Episodes)')
ax_reward.axhline(y=pruning_threshold, color='r', linestyle='--', label='Pruning Threshold')
# Only the first pruning line gets a legend entry. If no pruning happened,
# nothing is drawn (indexing episode_of_pruning[0] would raise IndexError).
for i, prune_episode in enumerate(episode_of_pruning):
    ax_reward.axvline(x=prune_episode, color='g', linestyle='--',
                      label='Pruning Episode' if i == 0 else '_nolegend_')
ax_reward.set(title="Rewards Over Training Episodes", xlabel="Episode", ylabel="Reward")
ax_reward.legend()

# Total Dense-layer node count after each episode
ax_nodes.plot(range(len(node_counts)), node_counts, color='b', label='Number of Nodes in Dense Layers')
ax_nodes.set(title="Number of Nodes Over Training Episodes", xlabel="Episode", ylabel="Number of Nodes")
ax_nodes.legend()

fig.tight_layout()
plt.show()