## Evaluation

Credits
- [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3)
- [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo)
- [Gymnasium](https://gymnasium.farama.org/)
- [rliable](https://github.com/google-research/rliable)

### Import utils

In [None]:
from scripts.evaluation_utils import *

### Define configs

In [None]:
# Notebook-wide settings. GAMES and GAMES_56 start out empty here and are
# filled in by later cells.

# global theme for plotting
THEME = 'rliable' # recommended
THEME_CHOICES = ['rliable', 'rl_zoo3', 'default'] # presumably the themes accepted by set_theme -- TODO confirm

# text rendering in matplotlib (default is mathtext)
USE_LATEX = True

# evaluation npz log info
# Set 'BASE_LOG_PATH' to the '--log-folder' used during training if you wish to
# regenerate the post-processed evaluation results, e.g., post_processed_results_56.pkl
BASE_LOG_PATH = '' # to be set by user
ALGO = 'ppo'
WILDCARD = '**'
EVALUATION_FILE = 'evaluations.npz'
N_EVALUATION_FILE = 1 # per test case (i.e., per game, self-attention type, seed)
N_EVALUATION_CHECKPOINT = 50 # total no. of eval checkpoints
N_EVALUATION_EPISODE = 5 # no. of episodes per eval checkpoint
EVALUATION_TIMESTEP_KEY = 'timesteps'
EVALUATION_RESULT_KEY = 'results'
EVALUATION_EP_LEN_KEY = 'ep_lengths'

# experiment variables
NUM_TIMESTEPS = 1e7 # total timesteps set in the ppo.yml (note: a float literal)
GAMES = [] # all games with 'NoFrameskip-v4' suffix (populated in a later cell)
GAMES_56 = [] # games this paper focuses on (populated in a later cell)
GAMES_EXCLUDED_IN_AGENT57_NO_VERSION = ['Adventure', 'AirRaid', 'Carnival', 'ElevatorAction', 'JourneyEscape', 'Pooyan'] # games excluded by Agent57
GAME_VERSION = 'NoFrameskip-v4'
SELF_ATTN_TYPES = ['NA', 'SWA', 'CWRA', 'CWCA', 'CWRCA']
SEEDS = ['0', '1', '10', '42', '1234']
SEEDS_TO_RUNS = {'0':0, '1':1, '10':2, '42':3, '1234':4} # map seeds to column indices in 'mean_per_eval' array

# post-processed evaluation results in pickle file format
PICKLE_FILE_PATH_56 = 'data/post_processed_results_56.pkl'

# smoothing factors
SMOOTH_WEIGHT = 0.98  # weight used in Debiased Exponential Moving Average (DEMA)

# figures
FIGURE_PATH_EVALUATION = 'figures'
LINEPLOT_PATH = 'lineplot' # learning curves
RLIABLE_PATH = 'rliable' # overall performance with stratified bootstrap CIs
DEMA_PATH = 'debiased_ema'
UNSMOOTH_PATH = 'unsmoothed'
FIGURE_EXT = '.pdf'
DPI = 300

# tables
TABLE_PATH_EVALUATION = 'tables'
TABLE_MARKDOWN_EXT = '.md' # for results table in markdown format
TABLE_LATEX_EXT = '.txt' # for results table in latex format

# winners
WINNER_PATH_EVALUATION = 'winners'
WINNER_EXT = '.json' # for winning games

In [None]:
# Collect the ids of all registered Atari environments whose id contains
# GAME_VERSION ('NoFrameskip-v4'), skipping the '-ram' variants.
# The entry point is passed through str() before the substring check because an
# env spec's entry point is not guaranteed to be a string (it can be a callable
# or None), in which case a bare `in` test would raise TypeError.
atari_game_list = [
    env_id
    for env_id, env_spec in gym.envs.registry.items()
    if GAME_VERSION in env_id
    and '-ram' not in env_id
    and 'AtariEnv' in str(env_spec.entry_point)
]

In [None]:
# GAMES holds its own copy of every Atari game id in atari_game_list
GAMES = list(atari_game_list)

In [None]:
# append the version suffix (GAME_VERSION) to every game name excluded by Agent57
GAMES_EXCLUDED_IN_AGENT57 = [
    f"{game_name}{GAME_VERSION}"
    for game_name in GAMES_EXCLUDED_IN_AGENT57_NO_VERSION
]

In [None]:
# GAMES_56: the games in GAMES (original order kept) minus those excluded by Agent57
_excluded_games = set(GAMES_EXCLUDED_IN_AGENT57)
GAMES_56 = [game_id for game_id in GAMES if game_id not in _excluded_games]

### Set theme for all the plots

In [None]:
# apply the plotting theme chosen in the configs (THEME) before any figure is drawn
set_theme(THEME)

### Re-generate post-processed results (optional)

You may uncomment the following cell to re-generate the post-processed evaluation results. Remember to set `BASE_LOG_PATH` in the `Define configs` section first.

In [None]:
# Optional: uncomment everything below to rebuild the post-processed results from
# the raw evaluation logs and assign them to `post_processed_results_56`.
# NOTE(review): later cells read `post_processed_results_56`, and no other
# assignment to it is visible in this part of the notebook -- confirm how it is
# loaded when this cell stays commented out.

# game_list = GAMES_56.copy()
# sat_list = SELF_ATTN_TYPES.copy()
# seed_list = SEEDS.copy()
# learning_curves = False # set to True to see all learning curves
# results_table = False # set to True to see the results table
# verbose = False # Set to True to see more logs
# save_file = True
# output_fp = PICKLE_FILE_PATH_56

# post_processed_results_56 = load_all_evaluation_files(game_list, sat_list, seed_list, \
#                                                       learning_curves=learning_curves, results_table=results_table, verbose=verbose, \
#                                                       save_file=save_file, output_fp=output_fp)

### Create results table

Get the results table using **all evaluation scores** (i.e., `use_last=False`) in **Markdown** format

In [None]:
# Results table over all evaluation scores (use_last=False) in Markdown form;
# the flags also request the .md table file and the winners JSON.
(
    results_table_df_56_all,
    results_table_markdown_56_all,
    winning_games_56_all,
) = get_results_table(
    post_processed_results_56,
    use_last=False,
    markdown=True,
    latex=False,
    save_table_as_md=True,
    save_table_as_txt=False,
    save_winner_as_json=True,
)

You may copy and paste the output of the above cell into this Markdown cell to better visualize the table.

In [None]:
# list the winning games of each self-attention type (type name printed in bold via ANSI codes)
for sat_name in SELF_ATTN_TYPES:
    games_won = winning_games_56_all[sat_name]
    print(f"'\033[1m{sat_name}\033[0m': {games_won}")

Get the results table using **all evaluation scores** (i.e., `use_last=False`) in **LaTeX** format

In [None]:
# Results table over all evaluation scores (use_last=False) in LaTeX form;
# the flags also request the .txt table file and the winners JSON.
# Note: this re-binds the three names assigned by the Markdown cell.
(
    results_table_df_56_all,
    results_table_markdown_56_all,
    winning_games_56_all,
) = get_results_table(
    post_processed_results_56,
    use_last=False,
    markdown=False,
    latex=True,
    save_table_as_md=False,
    save_table_as_txt=True,
    save_winner_as_json=True,
)

In [None]:
# bare expression: lets the notebook display the results table object
results_table_df_56_all

### Plot performance per game

Without smoothing

In [None]:
# Per-game learning curves WITHOUT smoothing.

# data to plot
results = post_processed_results_56
game_list = GAMES_56.copy()
sat_list = SELF_ATTN_TYPES.copy()
seed_list = SEEDS.copy()

# smoothing is switched off in this cell; the weight still comes from the
# configs (SMOOTH_WEIGHT) so that it is not repeated as a separate literal
smoothing = False
smooth_weight = SMOOTH_WEIGHT
no_million = False

# curve settings
hue = 'a'  # NOTE(review): meaning of 'a' is defined by plot_grouped_results_table -- confirm
n_boot = 10
seed_boot = 42
linewidth = 2

# figure layout
suptitle = None
fontsize_suptitle = 30
position_suptitle = (0.5, 1.1)
figsize = (40, 60)
fontsize_subtitle = 20
ncols = 6

# legend
fontsize_legend = 20
legend_title = 'SAT'
fontsize_legend_title = 22

savefig = True

plot_grouped_results_table(
    results,
    game_list,
    sat_list,
    seed_list,
    smoothing=smoothing,
    smooth_weight=smooth_weight,
    no_million=no_million,
    hue=hue,
    n_boot=n_boot,
    seed_boot=seed_boot,
    linewidth=linewidth,
    suptitle=suptitle,
    fontsize_suptitle=fontsize_suptitle,
    position_suptitle=position_suptitle,
    figsize=figsize,
    fontsize_subtitle=fontsize_subtitle,
    ncols=ncols,
    fontsize_legend=fontsize_legend,
    legend_title=legend_title,
    fontsize_legend_title=fontsize_legend_title,
    savefig=savefig,
)

With DEMA smoothing

In [None]:
# Per-game learning curves WITH DEMA smoothing.

# data to plot
results = post_processed_results_56
game_list = GAMES_56.copy()
sat_list = SELF_ATTN_TYPES.copy()
seed_list = SEEDS.copy()

# smoothing is switched on; the weight comes from the configs (SMOOTH_WEIGHT)
# so that it is not repeated as a separate literal
smoothing = True
smooth_weight = SMOOTH_WEIGHT
no_million = False

# curve settings
hue = 'a'  # NOTE(review): meaning of 'a' is defined by plot_grouped_results_table -- confirm
n_boot = 10
seed_boot = 42
linewidth = 2

# figure layout
suptitle = None
fontsize_suptitle = 30
position_suptitle = (0.5, 1.1)
figsize = (40, 60)
fontsize_subtitle = 20
ncols = 6

# legend
fontsize_legend = 20
legend_title = 'SAT'
fontsize_legend_title = 22

savefig = True

plot_grouped_results_table(
    results,
    game_list,
    sat_list,
    seed_list,
    smoothing=smoothing,
    smooth_weight=smooth_weight,
    no_million=no_million,
    hue=hue,
    n_boot=n_boot,
    seed_boot=seed_boot,
    linewidth=linewidth,
    suptitle=suptitle,
    fontsize_suptitle=fontsize_suptitle,
    position_suptitle=position_suptitle,
    figsize=figsize,
    fontsize_subtitle=fontsize_subtitle,
    ncols=ncols,
    fontsize_legend=fontsize_legend,
    legend_title=legend_title,
    fontsize_legend_title=fontsize_legend_title,
    savefig=savefig,
)

### rliable plots

Copy the Human and Random scores from [deep_rl_precipice_colab.ipynb](https://github.com/google-research/rliable/blob/master/deep_rl_precipice_colab.ipynb)

In [None]:
# Reference score table pasted as-is from the rliable colab: one game per line,
# fields separated by spaces. It is passed to get_reference_scores in the next cell.
# NOTE(review): the first two numeric columns appear to be the human and random
# scores -- confirm against get_reference_scores.
score_str = """alien 7127.70 227.80 297638.17 ± 37054.55 464232.43 ± 7988.66 741812.63
amidar 1719.50 5.80 29660.08 ± 880.39 31331.37 ± 817.79 28634.39
assault 742.00 222.40 67212.67 ± 6150.59 110100.04 ± 346.06 143972.03
asterix 8503.30 210.00 991384.42 ± 9493.32 999354.03 ± 12.94 998425.00
asteroids 47388.70 719.10 150854.61 ± 16116.72 431072.45 ± 1799.13 6785558.64
atlantis 29028.10 12850.00 1528841.76 ± 28282.53 1660721.85 ± 14643.83 1674767.20
bank_heist 753.10 14.20 23071.50 ± 15834.73 27117.85 ± 963.12 1278.98
battle_zone 37187.50 2360.00 934134.88 ± 38916.03 992600.31 ± 1096.19 848623.00
beam_rider 16926.50 363.90 300509.80 ± 13075.35 390603.06 ± 23304.09 4549993.53
berzerk 2630.40 123.70 61507.83 ± 26539.54 77725.62 ± 4556.93 85932.60
bowling 160.70 23.10 251.18 ± 13.22 161.77 ± 99.84 260.13
boxing 12.10 0.10 100.00 ± 0.00 100.00 ± 0.00 100.00
breakout 30.50 1.70 790.40 ± 60.05 863.92 ± 0.08 864.00
centipede 12017.00 2090.90 412847.86 ± 26087.14 908137.24 ± 7330.99 1159049.27
chopper_command 7387.80 811.00 999900.00 ± 0.00 999900.00 ± 0.00 991039.70
crazy_climber 35829.40 10780.50 565909.85 ± 89183.85 729482.83 ± 87975.74 458315.40
defender 18688.90 2874.50 677642.78 ± 16858.59 730714.53 ± 715.54 839642.95
demon_attack 1971.00 152.10 143161.44 ± 220.32 143913.32 ± 92.93 143964.26
double_dunk -16.40 -18.60 23.93 ± 0.06 24.00 ± 0.00 23.94
enduro 860.50 0.00 2367.71 ± 8.69 2378.66 ± 3.66 2382.44
fishing_derby -38.70 -91.70 86.97 ± 3.25 90.34 ± 2.66 91.16
freeway 29.60 0.00 32.59 ± 0.71 34.00 ± 0.00 33.03
frostbite 4334.70 65.20 541280.88 ± 17485.76 309077.30 ± 274879.03 631378.53
gopher 2412.50 257.60 117777.08 ± 3108.06 129736.13 ± 653.03 130345.58
gravitar 3351.40 173.00 19213.96 ± 348.25 21068.03 ± 497.25 6682.70
hero 30826.40 1027.00 114736.26 ± 49116.60 49339.62 ± 4617.76 49244.11
ice_hockey 0.90 -11.20 63.64 ± 6.48 86.59 ± 0.59 67.04
jamesbond 302.80 29.00 135784.96 ± 9132.28 158142.36 ± 904.45 41063.25
kangaroo 3035.00 52.00 24034.16 ± 12565.88 18284.99 ± 817.25 16763.60
krull 2665.50 1598.00 251997.31 ± 20274.39 245315.44 ± 48249.07 269358.27
kung_fu_master 22736.30 258.50 206845.82 ± 11112.10 267766.63 ± 2895.73 204824.00
montezuma_revenge 4753.30 0.00 9352.01 ± 2939.78 3000.00 ± 0.00 0.00
ms_pacman 6951.60 307.30 63994.44 ± 6652.16 62595.90 ± 1755.82 243401.10
name_this_game 8049.00 2292.30 54386.77 ± 6148.50 138030.67 ± 5279.91 157177.85
phoenix 7242.60 761.40 908264.15 ± 28978.92 990638.12 ± 6278.77 955137.84
pitfall 6463.70 -229.40 18756.01 ± 9783.91 0.00 ± 0.00 0.00
pong 14.60 -20.70 20.67 ± 0.47 21.00 ± 0.00 21.00
private_eye 69571.30 24.90 79716.46 ± 29515.48 40700.00 ± 0.00 15299.98
qbert 13455.00 163.90 580328.14 ± 151251.66 777071.30 ± 190653.94 72276.00
riverraid 17118.00 1338.50 63318.67 ± 5659.55 93569.66 ± 13308.08 323417.18
road_runner 7845.00 11.50 243025.80 ± 79555.98 593186.78 ± 88650.69 613411.80
robotank 11.90 2.20 127.32 ± 12.50 144.00 ± 0.00 131.13
seaquest 42054.70 68.40 999997.63 ± 1.42 999999.00 ± 0.00 999976.52
skiing -4336.90 -17098.10 -4202.60 ± 607.85 -3851.44 ± 517.52 -29968.36
solaris 12326.70 1236.30 44199.93 ± 8055.50 67306.29 ± 10378.22 56.62
space_invaders 1668.70 148.00 48680.86 ± 5894.01 67898.71 ± 1744.74 74335.30
star_gunner 10250.00 664.00 839573.53 ± 67132.17 998600.28 ± 218.66 549271.70
surround 6.50 -10.00 9.50 ± 0.19 10.00 ± 0.00 9.99
tennis -8.30 -23.80 23.84 ± 0.10 24.00 ± 0.00 0.00
time_pilot 5229.20 3568.00 405425.31 ± 17044.45 460596.49 ± 3139.33 476763.90
tutankham 167.60 11.40 2354.91 ± 3421.43 483.78 ± 37.90 491.48
up_n_down 11693.20 533.40 623805.73 ± 23493.75 702700.36 ± 8937.59 715545.61
venture 1187.50 0.00 2623.71 ± 442.13 2258.93 ± 29.90 0.40
video_pinball 17667.90 0.00 992340.74 ± 12867.87 999645.92 ± 57.93 981791.88
wizard_of_wor 4756.50 563.50 157306.41 ± 16000.00 183090.81 ± 6070.10 197126.00
yars_revenge 54576.90 3092.90 998532.37 ± 375.82 999807.02 ± 54.85 553311.46
zaxxon 9173.30 32.50 249808.90 ± 58261.59 370649.03 ± 19761.32 725853.90"""

In [None]:
# Extract the human and random reference scores from score_str; 'Surround' is
# passed in games_to_skip.
# NOTE(review): the exact shape of the two returned objects is defined by
# get_reference_scores -- confirm there.
ALL_HUMAN_SCORES, ALL_RANDOM_SCORES = get_reference_scores(score_str, games_to_skip=['Surround'])

Obtain Human Normalized Scores (HNS)

In [None]:
# inputs for the human-normalized-score (HNS) computation
results = post_processed_results_56
sat_list = list(SELF_ATTN_TYPES)
game_list = list(GAMES_56)

# bookkeeping sizes: games, runs (one per seed) and evaluation checkpoints
n_games = len(game_list)
n_runs = len(SEEDS)
n_eval = N_EVALUATION_CHECKPOINT

# three HNS dictionaries: last eval, mean over evals, and all evals
(
    last_eval_hns_dict_56,
    mean_eval_hns_dict_56,
    all_eval_hns_dict_56,
) = get_hns_dict(results, sat_list, game_list)

#### Aggregate performance

Use `mean evals`

In [None]:
# Aggregate performance using the mean-eval scores (use_last=False).

# both dicts are handed over
# (presumably `use_last` selects which one is plotted -- confirm in plot_aggregate_performance)
hns_dicts = [last_eval_hns_dict_56, mean_eval_hns_dict_56]
use_last = False

# bootstrap settings
task_bootstrap = False
reps = 2000
seed = None

# figure layout
subfigure_width = 3.5
row_height = 0.6
xlabel_y_coord = -0.02
interval_height = 0.6
wspace = 0.11
adjust_bottom = 0.2

# output
savefig = True
figname = "Aggregate performance (bootstrap over runs and use mean evals)"

plot_aggregate_performance(
    hns_dicts,
    use_last=use_last,
    task_bootstrap=task_bootstrap,
    reps=reps,
    seed=seed,
    subfigure_width=subfigure_width,
    row_height=row_height,
    xlabel_y_coord=xlabel_y_coord,
    interval_height=interval_height,
    wspace=wspace,
    adjust_bottom=adjust_bottom,
    savefig=savefig,
    figname=figname,
)

#### Performance profile

Use `mean evals`

In [None]:
# Performance profile using the mean-eval scores, the score distribution,
# and an inset plot.

hns_dicts = [last_eval_hns_dict_56, mean_eval_hns_dict_56]
use_last = False

# tau grid: start, stop and number of points
tau_start = 0
tau_stop = 8
tau_num = 81
use_score_distribution = True

# bootstrap settings
task_bootstrap = False
reps = 2000
seed = None

# main figure appearance
alpha = 0.15
figsize = (7, 5)
linestyles = None
linewidth = 2.0

# inset placement, axis limits and ticks
inset = True
inset_x_coord = 0.23
inset_y_coord = 0.4
inset_width = 0.55
inset_height = 0.45
inset_xlim_lower = 1
inset_xlim_upper = 4
inset_ylim_lower = 0.0
inset_ylim_upper = 0.25
inset_xticks = [1, 2, 3, 4]

legend_loc = 'best'

# output
savefig = True  # to save the figure, turn on plt.tight_layout()
figname = "Performance profiles with inset (bootstrap over runs and use mean evals, use score dist)"

plot_performance_profile(
    hns_dicts,
    use_last=use_last,
    tau_start=tau_start,
    tau_stop=tau_stop,
    tau_num=tau_num,
    use_score_distribution=use_score_distribution,
    task_bootstrap=task_bootstrap,
    reps=reps,
    seed=seed,
    figsize=figsize,
    alpha=alpha,
    linestyles=linestyles,
    linewidth=linewidth,
    inset=inset,
    inset_x_coord=inset_x_coord,
    inset_y_coord=inset_y_coord,
    inset_width=inset_width,
    inset_height=inset_height,
    inset_xlim_lower=inset_xlim_lower,
    inset_xlim_upper=inset_xlim_upper,
    inset_ylim_lower=inset_ylim_lower,
    inset_ylim_upper=inset_ylim_upper,
    inset_xticks=inset_xticks,
    legend_loc=legend_loc,
    savefig=savefig,
    figname=figname,
)

#### Probability of improvement

Use `mean evals`

In [None]:
# Probability of improvement using the mean-eval scores (use_last=False).

hns_dicts = [last_eval_hns_dict_56, mean_eval_hns_dict_56]
use_last = False

# bootstrap settings
task_bootstrap = False
reps = 1000
seed = None

# figure appearance
figsize = (8, 6)
alpha = 0.75
interval_height = 0.6
wrect = 5
ticklabelsize = 'x-large'
labelsize = 'x-large'
ylabel_x_coordinate = 0.08

# output
savefig = True
figname = "Probability of improvement (bootstrap over runs and use mean evals)"

plot_prob_of_improvement(
    hns_dicts,
    use_last=use_last,
    task_bootstrap=task_bootstrap,
    reps=reps,
    seed=seed,
    figsize=figsize,
    alpha=alpha,
    interval_height=interval_height,
    wrect=wrect,
    ticklabelsize=ticklabelsize,
    labelsize=labelsize,
    ylabel_x_coordinate=ylabel_x_coordinate,
    savefig=savefig,
    figname=figname,
)

Use `mean evals` and set `algo_X`

In [None]:
# Probability of improvement with `algo_X` set to each self-attention type in
# turn, using the mean-eval scores (use_last=False).
# Everything except `algo_X` and `figname` is the same for every plot, so those
# settings are assigned once here instead of on every loop iteration.
hns_dicts = [last_eval_hns_dict_56, mean_eval_hns_dict_56]
use_last = False

# bootstrap settings
task_bootstrap = False
reps = 1000
seed = None

# figure appearance
figsize = (4, 3)
alpha = 0.75
interval_height = 0.6
wrect = 5
ticklabelsize = 'x-large'
labelsize = 'x-large'
ylabel_x_coordinate = 0.15

savefig = True

# SELF_ATTN_TYPES is only read inside the loop, so it is iterated directly
# rather than through a copy
for sat in SELF_ATTN_TYPES:
    algo_X = sat
    figname = f"Probability of improvement - {algo_X} vs rest (bootstrap over runs and use mean evals)"

    plot_prob_of_improvement(
        hns_dicts,
        use_last=use_last,
        task_bootstrap=task_bootstrap,
        reps=reps,
        seed=seed,
        algo_X=algo_X,
        figsize=figsize,
        alpha=alpha,
        interval_height=interval_height,
        wrect=wrect,
        ticklabelsize=ticklabelsize,
        labelsize=labelsize,
        ylabel_x_coordinate=ylabel_x_coordinate,
        savefig=savefig,
        figname=figname,
    )

#### Sample Efficiency Curves

In [None]:
# Sample-efficiency curves built from the all-evals HNS dictionary.

hns_dict = all_eval_hns_dict_56
results = post_processed_results_56
downsample_factor = 5

# bootstrap settings
task_bootstrap = False
reps = 2000
seed = None

# figure appearance
figsize = (7, 5)
labelsize = 'xx-large'
ticklabelsize = 'xx-large'
marker = 'o'
linewidth = 2

# output
savefig = True
figname = "Sample efficiency (bootstrap over runs)"

plot_sample_efficiency(
    hns_dict,
    results,
    downsample_factor=downsample_factor,
    task_bootstrap=task_bootstrap,
    reps=reps,
    seed=seed,
    figsize=figsize,
    ticklabelsize=ticklabelsize,
    labelsize=labelsize,
    marker=marker,
    linewidth=linewidth,
    savefig=savefig,
    figname=figname,
)