In [None]:
%load_ext autoreload
%autoreload 2
from utils.plotting import get_colors, load_config, plot
from utils.data_handling import load_dqn_data
import numpy as np

#### Name explanations
* DQN -> standard DQN
* DAR_min^max -> Dynamic action repetition with small repetition and long repetition values
* tqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action to be concatenated to the state
* t-dqn -> TempoRL DQN with separate skip-DQN that expects the behaviour action as contextual input
* tdqn -> TempoRL DQN with shared state representation between the behaviour and skip action outputs.

In [None]:
import json
import glob
import os
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sb

from scipy.signal import savgol_filter
    

# Somehow the plotting functionality I ended up with was already covered for the tabular case.
# I should have just used the plot function from that.
def plotMultiple(data, ylim=None, title='', logStepY=False, max_steps=200, xlim=None, figsize=None,
                 alphas=None, smooth=5, savename=None, rewyticks=None, lenyticks=None,
                 skip_stdevs=None, dont_label=None, dont_plot=None, min_steps=None,
                 logRewY=False):
    """
    Two-panel plot over the number of performed training steps (shared, log-scaled x-axis):
    the top panel shows the test reward, the bottom panel shows the number of decisions
    (solid lines) and the number of all actions, i.e. episode lengths (dotted lines).

    The passed ``data`` dict is never modified; smoothing is applied to a local copy, so
    calling this function repeatedly on the same data always yields the same figure.

    data -> (dict[agent name] -> list([rewards, lens, decs, train_steps, train_episodes])) the data to plot.
            rewards, lens and decs are indexed as [0] = mean and [1] = stdev, train_steps[0] gives the x values.
    ylim -> (list) y-axis limit of the reward plot
    title -> (str) title on top of plot
    logStepY -> (bool) flag that indicates if the y-axis of the #Actions plot should be on log scale.
    max_steps -> (int) maximal episode length (upper y-limit of the #Actions plot)
    min_steps -> (int) optional minimum episode length (lower y-limit of the #Actions plot).
                 If not set (None) assumes 1 as min. An explicit 0 is respected.
    xlim -> (list) x-axis limits
    figsize -> (list) dimensions of the figure. Defaults to (20, 10)
    alphas -> (dict[agent name] -> float) the alpha value to use for plotting of specific agents.
              Agents without an entry are plotted with alpha 1.
    smooth -> (int) the window size for smoothing (should be odd if used. None or <= 0 deactivates this option)
    savename -> (str) filename to save the figure
    rewyticks -> (list) yticks for the reward plot
    lenyticks -> (list) yticks for the decisions plot
    skip_stdevs -> (list) list of names to not plot standard deviations for.
    dont_label -> (list) list of names to not label.
    dont_plot -> (list) list of names to not plot.
    logRewY -> (bool) flag that indicates if the reward y-axis should be on log scale.
    """
    # None instead of shared mutable list defaults; None behaves like an empty list.
    skip_stdevs = [] if skip_stdevs is None else skip_stdevs
    dont_label = [] if dont_label is None else dont_label
    dont_plot = [] if dont_plot is None else dont_plot

    # Shallow copy: smoothed series are stored here so the caller's dict stays untouched
    # (otherwise re-running a plotting cell would smooth already smoothed data).
    plot_data = dict(data)
    if smooth and smooth > 0:
        degree = 2
        for agent, series in data.items():
            series = list(series)  # we have to convert the tuple to lists
            for idx in range(3):  # 0: rewards, 1: episode lengths, 2: decisions
                stat = list(series[idx])
                stat[0] = savgol_filter(stat[0], smooth, degree)  # smooth the mean
                stat[1] = savgol_filter(stat[1], smooth, degree)  # smooth the stdev
                series[idx] = stat
            plot_data[agent] = series

    colors, color_map = get_colors()

    cfg = load_config()
    sb_cfg = cfg['plotting']['seaborn']
    sb.set_style(sb_cfg['style'])
    sb.set_context(sb_cfg['context']['context'],
                   font_scale=sb_cfg['context']['font scale'],
                   rc=sb_cfg['context']['rc2'])

    fig, ax = plt.subplots(2, figsize=figsize if figsize else (20, 10), dpi=100, sharex=True)
    ax[0].set_title(title)

    # Agents are drawn in reversed insertion order.
    for agent in list(plot_data.keys())[::-1]:
        if agent in dont_plot:
            continue

        # Only a missing entry falls back to full opacity; other errors are not swallowed.
        if alphas is None:
            alph = 1.
        else:
            try:
                alph = alphas[agent]
            except KeyError:
                alph = 1.

        # All dynamic-action-repetition variants share one color.
        color_name = color_map['dar'] if 'dar' in agent else color_map[agent]
        base_color = np.array(colors[color_name])
        rew, lens, decs, train_steps, _train_eps = plot_data[agent]
        steps = train_steps[0]
        # Arrays are required for the mean +/- stdev arithmetic below (also without smoothing).
        rew_mean, rew_std = np.asarray(rew[0]), np.asarray(rew[1])
        len_mean, len_std = np.asarray(lens[0]), np.asarray(lens[1])
        dec_mean, dec_std = np.asarray(decs[0]), np.asarray(decs[1])
        show_std = agent not in skip_stdevs

        label = agent.upper()
        if agent in ['t-dqn', 'tdqn', 'tqn']:
            label = 't-DQN'
        elif agent in dont_label:
            label = None

        #### Plot rewards
        ax[0].step(steps, rew_mean, where='post', c=colors[color_name], label=label,
                   alpha=alph, ls='-' if agent != 't-dqn' else '-.')
        if show_std:
            ax[0].fill_between(steps, rew_mean - rew_std, rew_mean + rew_std,
                               alpha=0.25 * alph, step='post',
                               color=colors[color_name])

        #### Plot decisions (solid)
        ax[1].step(steps, dec_mean, where='post',
                   c=base_color, ls='-',
                   alpha=alph)
        if show_std:
            ax[1].fill_between(steps, dec_mean - dec_std, dec_mean + dec_std,
                               alpha=0.125 * alph, step='post',
                               color=base_color)

        #### Plot lens (dotted, slightly darkened color)
        ax[1].step(steps, len_mean, where='post',
                   c=base_color * .75, alpha=alph,
                   ls=':')
        if show_std:
            ax[1].fill_between(steps, len_mean - len_std, len_mean + len_std,
                               alpha=0.25 * alph, step='post',
                               color=base_color * .75)

    # Reward axis
    if rewyticks is not None:
        ax[0].set_yticks(rewyticks)
    if ylim:
        ax[0].set_ylim(ylim)
    if xlim:
        ax[0].set_xlim(xlim)
    ax[0].set_ylabel('Reward')
    if len(data) - len(dont_label) < 5:
        ax[0].legend(ncol=1, loc='best', handlelength=.75)
    ax[1].semilogx()
    if logStepY:
        ax[1].semilogy()
    if logRewY:
        ax[0].semilogy()

    # Dummy lines far outside the visible range; they only provide the line-style legend entries.
    ax[1].plot([-999, -999], [-999, -999], ls=':', c='k', label='all')
    ax[1].plot([-999, -999], [-999, -999], ls='-', c='k', label='dec')
    ax[1].legend(loc='best', ncol=1, handlelength=.75)

    # `is None` (not truthiness) so that an explicitly requested lower limit of 0 is kept.
    lower_steps = 1 if min_steps is None else min_steps
    ax[1].set_ylim([lower_steps, max_steps])
    if xlim:
        ax[1].set_xlim(xlim)
    ax[1].set_ylabel('#Actions')
    ax[1].set_xlabel('#Train Steps')
    if lenyticks is not None:
        ax[1].set_yticks(lenyticks)
    plt.tight_layout()
    if savename:
        plt.savefig(savename)

    plt.show()


def get_best_to_plot(data, aucs, tempoRL=None):
    """
    Simple method to filter which lines to plot.

    Selects up to three agents: one TempoRL agent, the best 'dar' agent and the 'dqn' baseline.

    data -> (dict[agent name] -> plot data) all available data
    aucs -> (list of (agent name, auc)) score per agent, higher is better. The list is not modified.
    tempoRL -> (str) optional name of the TempoRL agent to plot. If None, the entry with the
               highest auc whose name contains 't' is used (skipped if there is no such entry).

    Returns a dict[agent name] -> plot data in the order TempoRL agent, best 'dar' agent, 'dqn'.
    Raises KeyError if 'dqn' (or an explicitly passed tempoRL name) is missing from data.
    """
    to_plot = dict()

    if tempoRL is None:
        ranked = sorted(aucs, key=lambda x: x[1], reverse=True)
        # Names containing 't' are the TempoRL variants (tqn, t-dqn, tdqn).
        # If none exists nothing is added, rather than silently falling back to the worst agent.
        best_tempo = next((elem[0] for elem in ranked if 't' in elem[0]), None)
        if best_tempo is not None:
            to_plot[best_tempo] = data[best_tempo]  # the absolute best
    else:
        to_plot[tempoRL] = data[tempoRL]

    # Best dynamic-action-repetition agent; the first one wins on ties.
    best_dar, best_dar_auc = None, -np.inf
    for elem in aucs:
        if 'dar' in elem[0] and elem[1] > best_dar_auc:
            best_dar, best_dar_auc = elem[0], elem[1]
    if best_dar is not None:  # there may be no 'dar' results at all
        to_plot[best_dar] = data[best_dar]

    to_plot['dqn'] = data['dqn']
    return to_plot

In [None]:
# Pong: load the tdqn, dqn and dar results, then plot reward and #actions.
data = {
    agent: load_dqn_data('*', folder, max_steps=2.5*10**6)
    for agent, folder in (
        ('tdqn', 'experiments/atari/pong/tdqn'),
        ('dqn', 'experiments/atari/pong/dqn'),
        (r'dar$_{1}^{10}$', 'experiments/atari/pong/dar'),
    )
}

plotMultiple(
    data,
    title='Pong',
    smooth=7,
    xlim=[10**4, 2.5*10**6],
    ylim=[-22, 22],
    min_steps=10**2,
    max_steps=3000,
    lenyticks=[10**2, 10**3, 2*10**3, 3*10**3],
    savename='pong_50_seeds.pdf',
)

In [None]:
# BeamRider: load the tdqn, dqn and dar results, then plot reward and #actions.
data = {
    agent: load_dqn_data('*', folder, max_steps=2.5*10**6)
    for agent, folder in (
        ('tdqn', 'experiments/atari/beam_rider/tdqn_3'),
        ('dqn', 'experiments/atari/beam_rider/dqn_3'),
        (r'dar$_{1}^{10}$', 'experiments/atari/beam_rider/dar_3'),
    )
}

plotMultiple(
    data,
    title='BeamRider',
    smooth=7,
    xlim=[10**4, 2.5*10**6],
    ylim=[0, 600],
    max_steps=1000,
    rewyticks=[0, 150, 300, 450, 600],
    savename='beamrider_15_seeds.pdf',
)

In [None]:
# Freeway: load the tdqn, dqn and dar results, then plot reward and #actions.
data = {
    agent: load_dqn_data('*', folder, max_steps=2.5*10**6)
    for agent, folder in (
        ('tdqn', 'experiments/atari/freeway/tdqn_3'),
        ('dqn', 'experiments/atari/freeway/dqn_3'),
        (r'dar$_{1}^{10}$', 'experiments/atari/freeway/dar_3'),
    )
}

plotMultiple(
    data,
    title='Freeway',
    smooth=7,
    xlim=[10**4, 2.5*10**6],
    ylim=[0, 35],
    max_steps=2100,
    rewyticks=[0, 11, 22, 33],
    savename='freeway_15_seeds.pdf',
)

In [None]:
# MsPacman: load the tdqn, dqn and dar results, then plot reward and #actions.
data = {
    agent: load_dqn_data('*', folder, max_steps=2.5*10**6)
    for agent, folder in (
        ('tdqn', 'experiments/atari/ms_pacman/tdqn_3'),
        ('dqn', 'experiments/atari/ms_pacman/dqn_3'),
        (r'dar$_{1}^{10}$', 'experiments/atari/ms_pacman/dar_3'),
    )
}

plotMultiple(
    data,
    title='MsPacman',
    smooth=7,
    xlim=[10**4, 2.5*10**6],
    ylim=[0, 750],
    max_steps=300,
    rewyticks=[0, 250, 500, 750],
    savename='mspacman_15_seeds.pdf',
)

In [None]:
# QBert (long runs, 5M training steps): load the tdqn, dqn and dar results and plot them.
data = {}

data['tdqn'] = load_dqn_data('*', 'experiments/atari/qbert_long/tdqn_3',
                             max_steps=5*10**6)
data['dqn'] = load_dqn_data('*', 'experiments/atari/qbert_long/dqn_3',
                            max_steps=5*10**6)
data[r'dar$_{1}^{10}$'] = load_dqn_data('*', 'experiments/atari/qbert_long/dar_3',
                                        max_steps=5*10**6)

# The output file name follows the '<game>_<n>_seeds.pdf' pattern of the other figures
# (it was previously misspelled as 'qbert_15_sees.pdf').
plotMultiple(data, title='QBert',
             ylim=[0, 1000],
             xlim=[10**4, 5*10**6], logRewY=False,
             max_steps=225, min_steps=0, rewyticks=[0, 250, 500, 750, 1000], lenyticks=[0, 50, 100, 150, 200],
             smooth=7, savename='qbert_15_seeds.pdf')

In [None]:
# QBert (2.5M training steps): only tdqn and dqn are loaded here; the figure is not saved.
data = {
    agent: load_dqn_data('*', folder, max_steps=2.5*10**6)
    for agent, folder in (
        ('tdqn', 'experiments/atari/qbert/tdqn_3'),
        ('dqn', 'experiments/atari/qbert/dqn_3'),
    )
}

plotMultiple(
    data,
    title='QBert',
    smooth=7,
    logRewY=False,
    xlim=[10**4, 2.5*10**6],
    ylim=[0, 1000],
    min_steps=0,
    max_steps=225,
    rewyticks=[0, 250, 500, 750, 1000],
    lenyticks=[0, 50, 100, 150, 200],
)