In [None]:
# Standard library
import itertools
import pickle

# Third-party
import numpy as np

# Display NumPy arrays with 3 decimal places for the rest of the notebook.
np.set_printoptions(precision=3)

In [2]:
from IPython.display import clear_output
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import seaborn as sns

# Global plot look: white grid background, "notebook" context with fonts
# scaled by 1.5 and line width raised to 2.5.
sns.set_style(style="whitegrid")
sns.set_context(context="notebook",
                font_scale=1.5,
                rc={"lines.linewidth": 2.5})
# Current default colour cycle; later cells pick colours from it by index.
palette = sns.color_palette()

In [3]:
import os
import imageio
import natsort

In [4]:
# Run identifiers of every trained agent variant. The text before the first
# '_' is the short experiment label used to look a run up later.
exp_list = [
    # D1QN / DQN runs (the "-original" run uses 100000 / freq 1000 settings)
    'D1QN_D1QN_Naive_1000_freq_100_324267_04171400',
    'D1QN-PER_D1QN_NaivePER_1000_freq_100_324267_04164957',
    'DQN_DQN_Naive_1000_freq_100_324267_04171233',
    'DQN-PER_DQN_NaivePER_1000_freq_100_324267_04164826',
    'DQN-PER-original_DQN_NaivePER_100000_freq_1000_324267_04163756',
    # D2QN runs
    'D2QN_D2QN_Naive_1000_freq_100_324267_04170846',
    'D2QN-PER_D2QN_NaivePER_1000_freq_100_324267_04165444',
    # DuDQN runs
    'DuDQN_DuDQN_Naive_1000_freq_100_324267_04170649',
    'DuDQN-PER_DuDQN_NaivePER_1000_freq_100_324267_04165300',
    # DuD2QN runs
    'DuD2QN_DuD2QN_Naive_1000_freq_100_324267_04170223',
    'DuD2QN-PER_DuD2QN_NaivePER_1000_freq_100_324267_04165910',
]

In [5]:
# Map each short experiment label (prefix before the first '_') to its full
# run identifier; a later duplicate label would overwrite an earlier one.
exp_dict = {run_name.split('_')[0]: run_name for run_name in exp_list}

In [6]:
# Bounds used below to classify terminal states (presumably mirroring the
# environment's termination limits -- confirm against the env version used).
cart_pos_threshold = 2.4                   # bound on |cart position|
theta_threshold = 12 * 2 * np.pi / 360     # 12 degrees in radians, ~0.21
pole_ang_threshold = 1 * theta_threshold   # bound on |pole angle| (scale 1)

In [None]:
# Identify successful terminal states and plot the KDE of the 2-D
# (cart position, pole angle) terminal state for each experiment.
import pandas as pd

for exp in ['D1QN','D1QN-PER','DQN','DQN-PER','D2QN','D2QN-PER','DuDQN','DuDQN-PER','DuD2QN','DuD2QN-PER']:
    log_name = exp_dict[exp]

    MEM_FILE = './memories/' + log_name + '.mpk'

    # Load memories. NOTE(review): pickle.load can execute arbitrary code --
    # only load .mpk files produced by our own training runs.
    with open(MEM_FILE, 'rb') as fpr:
        # Rows mix arrays (states) with scalars, so they are ragged; dtype=object
        # keeps one object row per record (NumPy >= 1.24 raises without it).
        memories = np.array(list(pickle.load(fpr)), dtype=object)

    # Column layout as used below: 0=state, 1=action, 2=reward,
    # 3=next state, 4=done flag.
    visited_states = np.stack(memories[:,0]).squeeze()
    actions = memories[:,1].astype(np.float32)
    rewards = memories[:,2].astype(np.float32)
    next_states = np.stack(memories[:,3]).squeeze()
    done = memories[:,4].astype(bool)  # builtin bool: np.bool was removed in NumPy 1.24

    terminal_mem = memories[done]
    terminal_states = np.stack(terminal_mem[:,0]).squeeze()

    # Keep terminal states that lie within 80% of both bounds, i.e. the
    # episode presumably did not end by exceeding the position/angle limits.
    success_states = terminal_states[np.abs(terminal_states[:,0]) < cart_pos_threshold*0.8]
    success_states = success_states[np.abs(success_states[:,2]) < pole_ang_threshold*0.8]

    success_termination_pos = success_states[:,0]
    success_termination_ang = success_states[:,2]

    # Combine data into DataFrame
    df = pd.DataFrame({'Cart Position': success_termination_pos,
                       'Pole Angle': success_termination_ang
                      })

    # Plot data onto seaborn JointGrid. x/y must be keywords: seaborn >= 0.12
    # rejects positional data variables with a TypeError.
    g = sns.JointGrid(x='Cart Position', y='Pole Angle',
                      data=df,
                      ratio=8,
                      xlim=[-cart_pos_threshold, cart_pos_threshold],
                      ylim=[-pole_ang_threshold, pole_ang_threshold])
    # `fill` replaces the deprecated `shade` argument.
    g = g.plot_joint(sns.kdeplot,
                     color=palette[3],
                     fill=False,
                     alpha=0.3
                    )
    g = g.plot_joint(plt.scatter,
                     s=10,
                     alpha=0.7,
                     color=palette[3])
    # Marginal KDEs: passing the data as x=/y= replaces the deprecated `vertical`.
    sns.kdeplot(x=df['Cart Position'], ax=g.ax_marg_x, color=palette[4], fill=True, legend=False)
    sns.kdeplot(y=df['Pole Angle'], ax=g.ax_marg_y, color=palette[5], fill=True, legend=False)

    fig = g.fig
    fig.subplots_adjust(top=0.93)
    fig.suptitle(exp, fontsize=16, fontweight='bold')

    # Saved to disk because the composite cell below reads these PNGs back.
    FIG_NAME = ("success_"+ exp + ".png")
    fig.savefig(FIG_NAME,
                dpi=300,
                format='png')
    # matplotlib Figure has no .close(); release it through pyplot so open
    # figures do not accumulate across loop iterations.
    plt.close(fig)

In [84]:
# Make a composite image by alpha-blending all the per-experiment images
# saved by the previous cell on top of each other.
from PIL import Image

LAYER_ALPHA = 125  # uniform opacity (0-255) applied to every layer

fig_list = []
for exp in ['D1QN','D1QN-PER','DQN','DQN-PER','D2QN','D2QN-PER','DuDQN','DuDQN-PER','DuD2QN','DuD2QN-PER']:
    FIG_NAME = ("success_"+ exp + ".png")
    # convert() returns an independent in-memory copy, so the file handle
    # can be closed right away by the context manager.
    with Image.open(FIG_NAME) as src:
        fig_list.append(src.convert("RGBA"))

# Named `layer` (not `fig`) so the matplotlib figure from the previous cell
# is not shadowed by a PIL image.
for layer in fig_list:
    layer.putalpha(LAYER_ALPHA)

# Blend every other layer in place onto the first image (as before,
# fig_list[0] ends up holding the composite).
composite = fig_list[0]
for layer in fig_list[1:]:
    composite.alpha_composite(layer)

# composite.save('composite_success.png', format='png')
composite