In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes, so edits to the local `src` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [None]:
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
from cycler import cycler
from matplotlib import lines

import __init__
from src.evaluation.data import (compute_iqm_time_series, compute_iqm_values, smooth_time_series, wandb_load_overview,
                                 wandb_load_runs, wandb_load_time_series)
from src.evaluation.plots import iqm_ci_plot, iqm_line_plot
from src.evaluation.utils import mm2in


---

#### Plotting setup

In [None]:
# Global matplotlib style shared by every figure in this notebook.
plt.rcParams.update({
    # --- axes -----------------------------------------------------------
    # Ten-colour cycle (looks like seaborn's "colorblind" palette — confirm).
    'axes.prop_cycle': cycler('color', ["#0173B2", "#DE8F05", "#029E73", "#D55E00", "#CC78BC",
                                        "#CA9161", "#FBAFE4", "#949494", "#ECE133", "#56B4E9"]),
    'axes.titlepad': 3.0,
    'axes.xmargin': 0.025,
    'axes.ymargin': 0.025,
    'axes.titlesize': 'medium',
    'axes.labelpad': 1.0,
    'axes.spines.right': False,
    'axes.spines.top': False,

    # --- fonts: Times everywhere, math text rendered without LaTeX ------
    'font.size': 7,
    'font.family': 'serif',
    'font.serif': 'Times',
    'mathtext.fontset': 'custom',
    'mathtext.it': 'Times',
    'mathtext.rm': 'Times',
    'text.usetex': False,

    # --- grid -----------------------------------------------------------
    'grid.alpha': 0.1,
    'grid.color': '#000000',

    # --- legend: compact and frameless ----------------------------------
    'legend.borderaxespad': 0.25,
    'legend.borderpad': 0.0,
    'legend.frameon': False,
    'legend.columnspacing': 1.0,
    'legend.handletextpad': 0.5,
    'legend.handlelength': 0.75,

    # --- lines ----------------------------------------------------------
    'lines.solid_capstyle': 'round',
    'lines.solid_joinstyle': 'round',

    # --- ticks ----------------------------------------------------------
    'xtick.major.pad': 2.0,
    'xtick.major.size': 2.0,
    'xtick.minor.size': 1.0,
    'ytick.major.pad': 2.0,
    'ytick.major.size': 2.0,
    'ytick.minor.size': 1.0,

    # --- boxplots: horizontal, median in C0, mean marker in C1 ----------
    # 'boxplot.showmeans' used to be listed twice (False, then True). In a
    # dict literal the last entry wins, so only the effective value is kept.
    'boxplot.showmeans': True,
    'boxplot.vertical': False,
    'boxplot.medianprops.color': 'C0',
    'boxplot.meanprops.marker': '.',
    'boxplot.meanprops.markeredgecolor': 'none',
    'boxplot.meanprops.markerfacecolor': 'C1',
    'boxplot.flierprops.marker': '.',
    'boxplot.flierprops.markersize': 4.0,
    'boxplot.flierprops.markeredgecolor': 'none',
    'boxplot.flierprops.markerfacecolor': '#0000007f',

    # --- layout: constrained layout with tight padding ------------------
    'figure.constrained_layout.use': True,
    'figure.constrained_layout.h_pad': 0.01,
    'figure.constrained_layout.w_pad': 0.01,
    'figure.constrained_layout.hspace': 0.05,
    'figure.constrained_layout.wspace': 0.05
})


In [None]:
# Keypoints shown in the figures, in subplot-column order.
keypoints = ['cube', 'target', 'end-effector']

# Logged metric name -> keypoint it refers to (same order as `keypoints`).
keypoint_labels = dict(zip(
    ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4'],
    keypoints,
))

# Tracking-error threshold per keypoint: drawn as a dashed reference line in
# the plots and used as the cut-off for the tracking-capability score.
thresholds = dict(zip(keypoints, [0.015, 0.015, 0.1]))

# Experiment name -> label used on plot axes and in legends.
# NOTE(review): 'sae-basic-fp32' maps to 'Basic-kp32' (fp vs. kp) — presumably
# intentional; confirm.
labels = dict([
    ('sae-keynet', 'KeyNet'),
    ('sae-keynet-vel-var-bg', 'KeyNet-vel-std-bg'),
    ('sae-basic-vel', 'Basic-vel'),
    ('sae-basic', 'Basic'),
    ('sae-basic-var', 'Basic-std'),
    ('sae-basic-bg', 'Basic-bg'),
    ('sae-basic-var-bg', 'Basic-std-bg'),
    ('sae-basic-vel-bg', 'Basic-vel-bg'),
    ('sae-basic-vel-var', 'Basic-vel-std'),
    ('sae-basic-vel-var-bg', 'Basic-vel-std-bg'),
    ('sae-basic-fp32', 'Basic-kp32'),
    ('sae-dsae', 'DSAE'),
])


---

#### Load overview table with all runs

In [None]:
# Entity and project name passed to the W&B run loader below.
entity, project = 'jonasreiher', 'sae-rl-SAE'


In [None]:
# Fetch the runs of the configured entity/project, then turn them into the
# overview table used by all later cells (presumably one row per run —
# confirm in src.evaluation.data).
runs = wandb_load_runs(entity, project)
runs_all = wandb_load_overview(runs)


---

#### Filter, group, and count runs

In [None]:
# Keep only runs that finished and reached the last step, then index them by
# dataset, experiment and run name so a group can be selected with `.loc`.
is_finished = runs_all['state'] == 'finished'
# 499 is presumably the index of the final training step — TODO confirm
reached_last_step = runs_all['global_step'] == 499

runs_grouped = (
    runs_all[is_finished & reached_last_step]
    .set_index(['dataset.id', 'experiment', 'name'])
    .sort_index()
)

# Alternative without filtering (kept for reference):
# runs_grouped = runs_all.set_index(['dataset.id', 'experiment', 'name']).sort_index()


In [None]:
# Number of remaining runs in every (dataset, experiment) group.
group_levels = ['dataset.id', 'experiment']
runs_grouped.groupby(group_levels, dropna=False)[['run']].count()


---
---

**Reconstruction Loss**

In [None]:
# Select the runs of one (dataset, experiment) group.
run_group = ('panda_push_custom', 'sae-basic-fp32')
df = runs_grouped.loc[run_group]


In [None]:
# Load the validation-loss history of the selected runs and smooth it.
loss_metric = 'valid/loss'
time_series = wandb_load_time_series(df, loss_metric)
time_series_smooth = smooth_time_series(time_series)


In [None]:
fig, ax = plt.subplots()

# Smoothed loss over training, drawn semi-transparent on a log-scaled y-axis.
epochs, losses = time_series_smooth.index, time_series_smooth.values
ax.plot(epochs, losses, color='C0', alpha=0.25)

ax.set_yscale('log')
ax.set_ylim(3e-4, 2e-2)
ax.set_xlabel('training epochs')
ax.set_ylabel('rec. loss')

# Size is given in millimetres (mm2in presumably converts to inches):
# 49 % of a 122 mm wide layout, 25 mm high.
fig_width_mm = 122 * 0.49
fig.set_size_inches(mm2in(fig_width_mm, 25))
fig.savefig('../local/paper/sae_basicfp32_loss.pdf')


---

In [None]:
# Select the runs of one (dataset, experiment) group.
run_group = ('panda_push_custom', 'sae-keynet')
df = runs_grouped.loc[run_group]


In [None]:
# Load the validation-loss history of the selected runs and smooth it.
loss_metric = 'valid/loss'
time_series = wandb_load_time_series(df, loss_metric)
time_series_smooth = smooth_time_series(time_series)


In [None]:
fig, ax = plt.subplots()

# Smoothed loss over training, drawn semi-transparent on a log-scaled y-axis.
epochs, losses = time_series_smooth.index, time_series_smooth.values
ax.plot(epochs, losses, color='C0', alpha=0.25)

ax.set_yscale('log')
ax.set_ylim(3e-4, 2e-2)
ax.set_xlabel('training epochs')
ax.set_ylabel('rec. loss')

# Size is given in millimetres (mm2in presumably converts to inches):
# 49 % of a 122 mm wide layout, 25 mm high.
fig_width_mm = 122 * 0.49
fig.set_size_inches(mm2in(fig_width_mm, 25))
fig.savefig('../local/paper/sae_keynet_loss.pdf')


---
---

**Individual Tracking Errors**

In [None]:
# Select the runs of one (dataset, experiment) group.
run_group = ('panda_push_custom', 'sae-basic-fp32')
df = runs_grouped.loc[run_group]


In [None]:
# Smoothed tracking-error history for each keypoint metric. The list order
# matches `keypoints` (cube, target, end-effector), which the plotting cells
# rely on when they pair series with titles and thresholds.
tracking_metrics = ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4']

# The loop index of the former enumerate() was never used, so the metrics are
# iterated directly.
time_series_smooth = [
    smooth_time_series(wandb_load_time_series(df, metric))
    for metric in tracking_metrics
]


In [None]:
fig, axes = plt.subplots(1, 3)

# Upper y-limit of each panel, in the order of `keypoints`. The earlier
# `set_ylim(0, None)` call was always overwritten by these fixed limits and
# has been dropped.
y_upper = [0.05, 0.05, 0.2]

for panel, keypoint, series, top in zip(axes, keypoints, time_series_smooth, y_upper):
    panel.plot(series.index, series, color='C0', alpha=0.25)
    panel.set_title(keypoint)
    # dashed reference line at this keypoint's tracking-error threshold
    panel.axhline(thresholds[keypoint], color='k', alpha=0.25, linestyle='--', lw=1)
    panel.set_xlabel('training epochs')
    panel.set_ylim([0.0, top])

axes[0].set_ylabel('tracking error')

fig.set_size_inches(mm2in(122, 22))
fig.savefig('../local/paper/sae_basicfp32_trackingerrors.pdf')


In [None]:
fig, axes = plt.subplots(1, 3)

# Upper y-limit of each panel, in the order of `keypoints`. The earlier
# `set_ylim(0, None)` call was always overwritten by these fixed limits and
# has been dropped.
y_upper = [0.05, 0.05, 0.2]

for panel, keypoint, series, top in zip(axes, keypoints, time_series_smooth, y_upper):
    panel.plot(series.index, series, color='C0', alpha=0.25)
    panel.set_title(keypoint)
    # dashed reference line at this keypoint's tracking-error threshold
    panel.axhline(thresholds[keypoint], color='k', alpha=0.25, linestyle='--', lw=1)
    # No x-label in this variant: presumably it is placed above another row
    # of panels in the paper — TODO confirm.
    panel.set_ylim([0.0, top])

axes[0].set_ylabel('tracking error')

fig.set_size_inches(mm2in(122, 19))
fig.savefig('../local/paper/sae_basicfp32_trackingerrors_combined.pdf')


---

In [None]:
# Select the runs of one (dataset, experiment) group.
run_group = ('panda_push_custom', 'sae-keynet-vel-var-bg')
df = runs_grouped.loc[run_group]


In [None]:
# Smoothed tracking-error history for each keypoint metric. The list order
# matches `keypoints` (cube, target, end-effector), which the plotting cells
# rely on when they pair series with titles and thresholds.
tracking_metrics = ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4']

# The loop index of the former enumerate() was never used, so the metrics are
# iterated directly.
time_series_smooth = [
    smooth_time_series(wandb_load_time_series(df, metric))
    for metric in tracking_metrics
]


In [None]:
fig, axes = plt.subplots(1, 3)

# Upper y-limit of each panel, in the order of `keypoints`. The earlier
# `set_ylim(0, None)` call was always overwritten by these fixed limits and
# has been dropped.
y_upper = [0.05, 0.05, 0.2]

for panel, keypoint, series, top in zip(axes, keypoints, time_series_smooth, y_upper):
    panel.plot(series.index, series, color='C0', alpha=0.25)
    panel.set_title(keypoint)
    # dashed reference line at this keypoint's tracking-error threshold
    panel.axhline(thresholds[keypoint], color='k', alpha=0.25, linestyle='--', lw=1)
    panel.set_xlabel('training epochs')
    panel.set_ylim([0.0, top])

axes[0].set_ylabel('tracking error')

fig.set_size_inches(mm2in(122, 22))
fig.savefig('../local/paper/sae_keynetvelvarbg_trackingerrors.pdf')


In [None]:
fig, axes = plt.subplots(1, 3)

# Upper y-limit of each panel, in the order of `keypoints`. The earlier
# `set_ylim(0, None)` call was always overwritten by these fixed limits and
# has been dropped.
y_upper = [0.05, 0.05, 0.2]

for panel, keypoint, series, top in zip(axes, keypoints, time_series_smooth, y_upper):
    panel.plot(series.index, series, color='C0', alpha=0.25)
    # No panel title in this variant: presumably it is placed below another
    # row of panels that already carries the titles — TODO confirm.
    # dashed reference line at this keypoint's tracking-error threshold
    panel.axhline(thresholds[keypoint], color='k', alpha=0.25, linestyle='--', lw=1)
    panel.set_xlabel('training epochs')
    panel.set_ylim([0.0, top])

axes[0].set_ylabel('tracking error')

fig.set_size_inches(mm2in(122, 19))
fig.savefig('../local/paper/sae_keynetvelvarbg_trackingerrors_combined.pdf')


---

In [None]:
# Run groups of the two experiments that are compared in one figure.
dataset_id = 'panda_push_custom'
df_basic = runs_grouped.loc[dataset_id, 'sae-basic-fp32']
df_keynet = runs_grouped.loc[dataset_id, 'sae-keynet-vel-var-bg']


In [None]:
# Smoothed tracking-error histories of both experiments, one list entry per
# keypoint metric (same order as `keypoints`: cube, target, end-effector).
tracking_metrics = ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4']

time_series_smooth_basic = []
time_series_smooth_keynet = []

# The loop index of the former enumerate() was never used, so the metrics are
# iterated directly. Loading stays interleaved per metric as before.
for metric in tracking_metrics:
    time_series_smooth_basic.append(smooth_time_series(wandb_load_time_series(df_basic, metric)))
    time_series_smooth_keynet.append(smooth_time_series(wandb_load_time_series(df_keynet, metric)))


In [None]:
fig, axes = plt.subplots(1, 3)

# Upper y-limit of each panel, in the order of `keypoints`. The earlier
# `set_ylim(0, None)` call was always overwritten by these fixed limits and
# has been dropped.
y_upper = [0.05, 0.05, 0.2]

panels = zip(axes, keypoints, time_series_smooth_basic, time_series_smooth_keynet, y_upper)
for panel, keypoint, series_basic, series_keynet, top in panels:
    # first experiment in C0, second in C1
    panel.plot(series_basic.index, series_basic, color='C0', alpha=0.25)
    panel.plot(series_keynet.index, series_keynet, color='C1', alpha=0.25)
    panel.set_title(keypoint)
    # dashed reference line at this keypoint's tracking-error threshold
    panel.axhline(thresholds[keypoint], color='k', alpha=0.25, linestyle='--', lw=1)
    panel.set_xlabel('training epochs')
    panel.set_ylim([0.0, top])

axes[0].set_ylabel('tracking error')

fig.set_size_inches(mm2in(122, 22))
fig.savefig('../local/paper/sae_trackingerrors_combined.pdf')


In [None]:
# Stand-alone legend figure for the two experiments.
#
# The handles are proxy artists in the first two prop-cycle colours (C0, C1).
# They were previously taken from a throw-away figure and stored in a variable
# named `lines`, which rebound the imported `matplotlib.lines` module and made
# the later `lines.Line2D(...)` calls fail on a fresh top-to-bottom run.
legend_handles = [lines.Line2D([], [], color=color) for color in ('C0', 'C1')]

fig, ax = plt.subplots()
ax.axis('off')  # only the legend should be visible

ax.legend(legend_handles, ['Basic-kp32', 'KeyNet-vel-std-bg'], ncols=2)

fig.set_size_inches(mm2in(122, 3.5))
fig.savefig('../local/paper/sae_legend.pdf')


---
---

**Boxplots**

In [None]:
# Experiments shown in the boxplot figure, in plotting order.
datasets = ['panda_push_custom']
experiments = [
    'sae-dsae',
    'sae-basic',
    'sae-basic-fp32',
    'sae-basic-vel-var-bg',
    'sae-keynet',
    'sae-keynet-vel-var-bg',
]
# every (dataset, experiment) combination, used as `.loc` key per box
run_sets = list(product(datasets, experiments))

# `filter_by` is not a stock pandas method — presumably provided by the
# project; confirm where it is registered.
selection = {'dataset.id': datasets, 'experiment': experiments}
df = runs_grouped.filter_by(selection).sort_index()


In [None]:
fig, axes = plt.subplots(1, 3)

tracking_metrics = ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4']

# One tick label per run set (identical for all three panels).
tick_labels = [labels[experiment] for _, experiment in run_sets]

# Proxy artists explaining the median line and the mean marker.
median = lines.Line2D([], [], marker='|', c='C0', ls='')
mean = lines.Line2D([], [], marker='.', markeredgecolor='none', c='C1', ls='')

for panel, keypoint, var in zip(axes, keypoints, tracking_metrics):
    # one box per run set, built from that group's values of this metric
    data = [df.loc[run_set][var] for run_set in run_sets]
    panel.boxplot(data, vert=False, widths=0.6)
    # dashed reference line at this keypoint's tracking-error threshold
    panel.axvline(thresholds[keypoint], color='k', alpha=0.25, linestyle='--', lw=1)
    panel.set_xlabel(f'{keypoint_labels[var]} tracking error')
    panel.set_xlim(0, None)
    panel.legend([median, mean], ['median', 'mean'], bbox_to_anchor=(1,1), loc='lower right', ncol=2)

# Experiment names only on the left-most panel.
axes[0].set_yticklabels(tick_labels)
for panel in axes[1:]:
    panel.set_yticklabels([])

fig.set_size_inches(mm2in(122, 9.2 + 6 * 2.5))
fig.savefig('../local/paper/sae_boxplots.pdf')


---

In [None]:
# Experiments shown in the ablation boxplot figure, in plotting order.
datasets = ['panda_push_custom']
experiments = [
    'sae-keynet',
    'sae-keynet-vel-var-bg',
    'sae-basic-vel',
    'sae-basic',
    'sae-basic-var',
    'sae-basic-var-bg',
    'sae-basic-vel-var',
    'sae-basic-bg',
    'sae-basic-vel-bg',
    'sae-basic-vel-var-bg',
]
# every (dataset, experiment) combination, used as `.loc` key per box
run_sets = list(product(datasets, experiments))

# `filter_by` is not a stock pandas method — presumably provided by the
# project; confirm where it is registered.
selection = {'dataset.id': datasets, 'experiment': experiments}
df = runs_grouped.filter_by(selection).sort_index()


In [None]:
fig, axes = plt.subplots(1, 3)

tracking_metrics = ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4']

# One tick label per run set (identical for all three panels).
tick_labels = [labels[experiment] for _, experiment in run_sets]

# Proxy artists explaining the median line and the mean marker.
median = lines.Line2D([], [], marker='|', c='C0', ls='')
mean = lines.Line2D([], [], marker='.', markeredgecolor='none', c='C1', ls='')

for panel, keypoint, var in zip(axes, keypoints, tracking_metrics):
    # one box per run set, built from that group's values of this metric
    data = [df.loc[run_set][var] for run_set in run_sets]
    panel.boxplot(data, vert=False, widths=0.6)
    # dashed reference line at this keypoint's tracking-error threshold
    panel.axvline(thresholds[keypoint], color='k', alpha=0.25, linestyle='--', lw=1)
    panel.set_xlabel(f'{keypoint_labels[var]} tracking error')
    panel.set_xlim(0, None)
    panel.legend([median, mean], ['median', 'mean'], bbox_to_anchor=(1,1), loc='lower right', ncol=2)

# Experiment names only on the left-most panel.
axes[0].set_yticklabels(tick_labels)
for panel in axes[1:]:
    panel.set_yticklabels([])

fig.set_size_inches(mm2in(122, 9.2 + 10 * 2.5))
fig.savefig('../local/paper/sae_ablations_boxplots.pdf')


---
---

**Tracking Capability**

In [None]:
# Runs of every experiment needed by the two tracking-capability figures.
# The list previously named 'sae-basic' and 'sae-basic-var-bg' twice; each
# experiment is now listed exactly once (first-occurrence order kept).
# `filter_by` is not a stock pandas method — presumably provided by the
# project; confirm where it is registered.
df = runs_grouped.filter_by({
    'dataset.id': ['panda_push_custom'],
    'experiment': [
        'sae-keynet-vel-var-bg',
        'sae-basic-vel-var-bg',
        'sae-basic-fp32',
        'sae-keynet',
        'sae-basic',
        'sae-dsae',
        'sae-basic-vel',
        'sae-basic-var',
        'sae-basic-bg',
        'sae-basic-var-bg',
        'sae-basic-vel-bg',
        'sae-basic-vel-var',
    ]
}).sort_index()


---

In [None]:
# Tracking-error columns of the selected runs, one per keypoint metric.
tracking_metrics = ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4']
tracking_errors = df[tracking_metrics]

# Threshold for each column, looked up through `keypoint_labels`. This
# replaces the comparison against `thresholds.values()`, which only lined up
# with the columns because the dict happened to be written in the same order.
metric_thresholds = [thresholds[keypoint_labels[metric]] for metric in tracking_metrics]

# Tracking capability: per experiment, the share of runs whose error is at or
# below the threshold (mean of the boolean comparison).
tracking_capability_individual = (tracking_errors <= metric_thresholds).groupby('experiment').mean()


In [None]:
# Keep only the experiments of the main comparison, in plotting order.
main_experiments = [
    'sae-dsae',
    'sae-basic',
    'sae-basic-fp32',
    'sae-basic-vel-var-bg',
    'sae-keynet',
    'sae-keynet-vel-var-bg',
]
tracking_capability_individual = tracking_capability_individual.loc[main_experiments]


In [None]:
fig, ax = plt.subplots()

# One horizontal bar per experiment, stacked across the three keypoints.
bar_labels = [labels[exp] for exp in tracking_capability_individual.index]

# `left` is the running start position of the next segment. It is rebuilt
# with `left + widths` rather than updated in place, and the unused
# enumerate() index has been removed.
left = 0
for kp in ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4']:
    widths = tracking_capability_individual[kp].values
    ax.barh(bar_labels, widths, left=left, label=keypoint_labels[kp])
    left = left + widths

ax.xaxis.grid()
ax.set_axisbelow(True)  # draw the grid behind the bars
ax.set_xlabel('tracking capability')
ax.legend(loc='lower right', bbox_to_anchor=(1.0, 1.0), ncols=3)
# three stacked shares of at most 1.0 each, plus a small margin
ax.set_xlim(0, 3.05)

fig.set_size_inches(mm2in(122, 10.3 + 6 * 2.5))
fig.savefig('../local/paper/sae_trackingcapability.pdf')


---

In [None]:
# Tracking-error columns of the selected runs, one per keypoint metric.
tracking_metrics = ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4']
tracking_errors = df[tracking_metrics]

# Threshold for each column, looked up through `keypoint_labels`. This
# replaces the comparison against `thresholds.values()`, which only lined up
# with the columns because the dict happened to be written in the same order.
metric_thresholds = [thresholds[keypoint_labels[metric]] for metric in tracking_metrics]

# Tracking capability: per experiment, the share of runs whose error is at or
# below the threshold (mean of the boolean comparison). Recomputed here
# because the previous section reduced the table to a subset of experiments.
tracking_capability_individual = (tracking_errors <= metric_thresholds).groupby('experiment').mean()


In [None]:
# Keep only the experiments of the ablation study, in plotting order.
ablation_experiments = [
    'sae-keynet',
    'sae-keynet-vel-var-bg',
    'sae-basic-vel',
    'sae-basic',
    'sae-basic-var',
    'sae-basic-var-bg',
    'sae-basic-vel-var',
    'sae-basic-bg',
    'sae-basic-vel-bg',
    'sae-basic-vel-var-bg',
]
tracking_capability_individual = tracking_capability_individual.loc[ablation_experiments]


In [None]:
fig, ax = plt.subplots()

# One horizontal bar per experiment, stacked across the three keypoints.
bar_labels = [labels[exp] for exp in tracking_capability_individual.index]

# `left` is the running start position of the next segment. It is rebuilt
# with `left + widths` rather than updated in place, and the unused
# enumerate() index has been removed.
left = 0
for kp in ['tracking_errors/affine/0', 'tracking_errors/affine/1', 'tracking_errors/affine/4']:
    widths = tracking_capability_individual[kp].values
    ax.barh(bar_labels, widths, left=left, label=keypoint_labels[kp])
    left = left + widths

ax.xaxis.grid()
ax.set_axisbelow(True)  # draw the grid behind the bars
ax.set_xlabel('tracking capability')
ax.legend(loc='lower right', bbox_to_anchor=(1.0, 1.0), ncols=3)
# three stacked shares of at most 1.0 each, plus a small margin
ax.set_xlim(0, 3.05)

fig.set_size_inches(mm2in(122, 10.3 + 10 * 2.5))
fig.savefig('../local/paper/sae_tc_ablations.pdf')


In [None]:
# Number of runs behind each bar of the tracking-capability figures.
group_levels = ['dataset.id', 'experiment']
df.groupby(group_levels, dropna=False)[['run']].count()
