In [1]:
# Uncomment the following line to render the plots inline; leave it commented out to view them in a separate window:
# %matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pickle
from scipy.stats import rankdata

from DiscreteRLFlyEnv import DiscreteFlyEnv
from computations import DEG2RAD, RAD2DEG
from utils import map_range, smooth_array

Load a Trained Q-Table

In [2]:
# Load the Q-table saved at the end of training.
q_table = np.load('train_results/q_table_final.npy')

# ndarray.size counts every entry, so the report stays correct for any number of axes.
print(f'Q Table Shape: {q_table.shape} = {q_table.size} entries')

Q Table Shape: (59, 9, 7) = 3717 entries


Create an environment with the same configuration that the Q-Table was trained on

In [3]:
# The two leading Q-table axes are passed as the state space and the third as the number of actions.
state_space, n_actions = q_table.shape[:2], q_table.shape[2]
env = DiscreteFlyEnv(
    config_path='config.json',
    state_space=state_space,
    action_space=n_actions,
)

# Print the environment's spaces so they can be compared against the Q-table shape above.
print("Observation Space", env.observation_space)
print("Sample observation", env.observation_space.sample())  # one randomly drawn observation
print("Action Space Shape", env.action_space.n)
print("Action Space Sample", env.action_space.sample())


Observation Space MultiDiscrete([59  9])
Sample observation [16  0]
Action Space Shape 7
Action Space Sample 2


# Visualizing the Q-Table

Probably the clearest view is a heatmap: the x-axis shows the pitch buckets, the y-axis shows the delta-pitch buckets, and each cell's value is the delta-phi of the best (highest Q-value) action for that state.

In [4]:
def q_table_heatmap(q_table, env: DiscreteFlyEnv):
    """Show, for every discretized state, the delta-phi of the highest-valued action.

    Args:
        q_table: Array indexed as [pitch bucket, delta-pitch bucket, action].
        env: Environment providing `pitch_range`, `delta_pitch_range`,
            `delta_phi_range` and `action_space.n`.
    """
    n_pitch_buckets, n_delta_pitch_buckets = q_table.shape[:2]

    # Lower edge of every bucket. linspace(..., endpoint=False) always returns exactly
    # one value per bucket, whereas arange with a float step can return one extra
    # element because of rounding, which would misalign the tick labels.
    angle_buckets = np.round(
        np.linspace(env.pitch_range[0], env.pitch_range[1], n_pitch_buckets, endpoint=False) * RAD2DEG, 2)
    delta_pitches = np.round(
        np.linspace(env.delta_pitch_range[0], env.delta_pitch_range[1], n_delta_pitch_buckets, endpoint=False), 2)

    # Best action index per state, transposed so that rows are delta-pitch buckets
    # and columns are pitch buckets.
    best_actions = np.argmax(q_table, axis=-1).T
    # Rescale the action index to the delta-phi range it stands for.
    best_delta_phis = map_range(best_actions, 0, env.action_space.n - 1, env.delta_phi_range[0], env.delta_phi_range[1])

    plt.figure(figsize=(10, 8))
    sns.heatmap(best_delta_phis, xticklabels=angle_buckets, yticklabels=delta_pitches, annot=False,
                cmap="coolwarm", cbar_kws={'label': 'Delta Phi'})
    plt.title('Heatmap of Q-table Aggregated Across Actions')
    plt.xlabel('Pitch [degrees]')
    plt.ylabel('Delta Pitch')
    plt.show()

# Render the heatmap for the loaded Q-table and its matching environment.
q_table_heatmap(q_table, env)

Visualize via a line graph, where the x-axis is the pitch buckets, the y-axis is the delta-phi, and there is one line for each delta-pitch bucket.

In [5]:
def q_table_line_graph(q_table, env: DiscreteFlyEnv):
    """Plot the best action's delta-phi against pitch, one line per delta-pitch bucket.

    Args:
        q_table: Array indexed as [pitch bucket, delta-pitch bucket, action].
        env: Environment providing `pitch_range`, `delta_pitch_range`,
            `delta_phi_range` and `action_space.n`.
    """
    n_pitch_buckets, n_delta_pitch_buckets = q_table.shape[:2]

    # Lower edge of every bucket; linspace(..., endpoint=False) guarantees exactly one
    # value per bucket (arange with a float step can return one element too many).
    angle_buckets = np.linspace(env.pitch_range[0], env.pitch_range[1], n_pitch_buckets, endpoint=False) * RAD2DEG
    delta_pitches = np.linspace(env.delta_pitch_range[0], env.delta_pitch_range[1], n_delta_pitch_buckets, endpoint=False)

    plt.figure(figsize=(15, 5))
    # Index the Q-table by bucket position directly. Rescaling a bucket's lower edge
    # to [0, n-1] and truncating (assuming map_range is a linear rescale) selects
    # bucket 0 twice and never reaches the last bucket, so lines end up mislabeled.
    for bucket_index, delta_pitch in enumerate(delta_pitches):
        best_actions = np.argmax(q_table[:, bucket_index, :], axis=-1)
        # Rescale the action index to the delta-phi range it stands for.
        delta_phis = map_range(best_actions, 0, env.action_space.n - 1, env.delta_phi_range[0], env.delta_phi_range[1])
        plt.plot(angle_buckets, delta_phis, label=f"Pitch Dot = {delta_pitch:.2f}")
        plt.scatter(angle_buckets, delta_phis)  # Add circles for every point

    plt.xlabel('Pitch [degrees]')
    plt.ylabel('Delta Phi')
    plt.title('Delta Phi per State')
    plt.legend()
    plt.show()

# Render the line graph for the loaded Q-table and its matching environment.
q_table_line_graph(q_table, env)

# Visualize the average pitch per episode throughout training

The Q-Table was trained to make the pitch converge to -45 degrees.

Load the pitch history throughout the training of the Q-Table:

In [6]:
# NOTE: pickle.load can execute arbitrary code; only open files produced by your own training runs.
with open('train_results/pitch_history_final.pickle', 'rb') as history_file:
    pitch_history = pickle.load(history_file)

# Number of episodes spent in each curriculum level, in training order.
curriculum_level_durations = [
    300, 300, 300, 300, 300, 400, 600,
    300, 300, 500,
    300, 300, 500,
    300, 400, 500,
]

print(f'Num episodes in pitch history: {len(pitch_history)}')
print(f'Median number of data-points per episode: {np.median(list(map(len, pitch_history)))}')

Num episodes in pitch history: 5879
Median number of data-points per episode: 60.0


In [7]:
def plot_angle_history(angles_per_episode, angle_name, level_durations):
    '''
    Plot smoothed per-episode statistics of an angle and mark where each curriculum level starts.

    angles_per_episode: 2d array, where each element is a list of the angle's history in that episode
    angle_name: name of the angle being tracked (pitch, roll, yaw, etc.)
    level_durations: an array where each element is the number of episodes for that level
    '''

    # Extract episode means/std/max/min of each episode
    episode_means = [np.mean(episode_angles) for episode_angles in angles_per_episode]
    episode_stds = [np.std(episode_angles) for episode_angles in angles_per_episode]
    episode_max = [np.max(episode_angles) for episode_angles in angles_per_episode]
    episode_min = [np.min(episode_angles) for episode_angles in angles_per_episode]

    # Smoothen the arrays for nicer plotting
    alpha = 0.995
    smoothed_means = np.array(smooth_array(episode_means, alpha=alpha))
    smoothed_stds = np.array(smooth_array(episode_stds, alpha=alpha))
    smoothed_max = np.array(smooth_array(episode_max, alpha=alpha))
    smoothed_min = np.array(smooth_array(episode_min, alpha=alpha))

    # Display the plot
    plt.plot(smoothed_means, label=f'{angle_name} Means', color='r')
    plt.plot(smoothed_max, label=f'Max {angle_name}', alpha=0.4)
    plt.plot(smoothed_min, label=f'Min {angle_name}', alpha=0.4)
    plt.fill_between(range(len(smoothed_means)), smoothed_means - smoothed_stds,
                     smoothed_means + smoothed_stds, color='b', alpha=0.2, label='1 Std Dev')

    # A level starts where all previous levels end: the cumulative duration up to and
    # including the level, minus the level's own duration. Subtracting only the first
    # level's duration misplaces every marker after a level of a different length.
    durations = np.asarray(level_durations)
    level_starts = np.cumsum(durations) - durations
    for i, level_start in enumerate(level_starts):
        plt.axvline(x=level_start, color='orange', linestyle='--', linewidth=0.5, label=f'Curriculum Level {i+1}')
    plt.xlabel('Episode')
    plt.ylabel(f'{angle_name} [degrees]')
    plt.title(f'{angle_name} per Episode')
    plt.grid(alpha=0.1)
    plt.legend(loc='upper left')
    plt.show()


# Plot the recorded pitch history together with the curriculum level markers.
plot_angle_history(pitch_history, 'Pitch', curriculum_level_durations)

# Visualize the amount of time the fly spent in each state throughout training

In [8]:
# Load the state heatmap array stored with the training results.
state_heatmap_iter = np.load('train_results/state_heatmap_final.npy')

In [9]:
def show_state_heatmap(hm, title='Heatmap of Q-table Aggregated Across Actions'):
    """Draw a heatmap of a per-state array with pitch on the x-axis and delta pitch on the y-axis.

    Args:
        hm: 2D array indexed as [pitch bucket, delta-pitch bucket].
        title: Title shown above the plot.

    Reads the pitch and delta-pitch ranges from the notebook-level `env`.
    """
    # Derive the bucket counts from the array being plotted, so the tick labels
    # always match its shape instead of depending on a separately loaded Q-table.
    n_pitch_buckets, n_delta_pitch_buckets = hm.shape[:2]

    # Lower edge of every bucket; linspace(..., endpoint=False) guarantees exactly one
    # label per bucket (arange with a float step can return one element too many).
    angle_buckets = np.round(
        np.linspace(env.pitch_range[0], env.pitch_range[1], n_pitch_buckets, endpoint=False) * RAD2DEG, 2)
    delta_pitches = np.round(
        np.linspace(env.delta_pitch_range[0], env.delta_pitch_range[1], n_delta_pitch_buckets, endpoint=False), 2)

    plt.figure(figsize=(10, 8))
    # Transpose so that rows are delta-pitch buckets and columns are pitch buckets.
    sns.heatmap(hm.T, xticklabels=angle_buckets, yticklabels=delta_pitches, annot=False, cmap="coolwarm")
    plt.title(title)
    plt.xlabel('Pitch [degrees]')
    plt.ylabel('Delta Pitch')
    plt.show()


def show_percentile_state_heatmap(hm, title='State Heatmap (Percentiles)'):
    """Plot each cell of `hm` as its percentile rank among all cells.

    Ties share their average rank. The percentile grid keeps the shape of `hm`
    and is drawn with `show_state_heatmap`.
    """
    # rankdata ranks the flattened array, starting at 1.
    flat_ranks = rankdata(hm, method='average')
    percentile_grid = (flat_ranks / hm.size * 100).reshape(hm.shape)

    show_state_heatmap(percentile_grid, title)

# Show the loaded state heatmap as percentile ranks.
show_percentile_state_heatmap(state_heatmap_iter)