In [1]:
import os
# Move two directory levels up from where the kernel started; parent_dir is the
# base for the data/figures output paths used later. The starting directory is
# remembered so that re-running this cell does not climb two more levels each
# time (a bare os.chdir('../..') is not idempotent).
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
parent_dir = os.path.abspath(os.path.join(_NOTEBOOK_START_DIR, '..', '..'))
os.chdir(parent_dir)

In [2]:
import numpy as np
import pickle5 as pickle
from scipy.stats import chisquare
from scipy.stats import entropy
from scipy.stats import ks_2samp
import glob, sys
import matplotlib.pyplot as plt
import matplotlib
# Rebuild matplotlib's font cache so fonts installed after the cache was built
# (e.g. the serif fonts requested in the style cell) can be found. `_rebuild` is
# a private helper that newer matplotlib releases no longer provide, so call it
# only when present instead of failing the whole import cell with ImportError.
# NOTE(review): on such versions, clearing the fontlist cache file in
# matplotlib.get_cachedir() serves the same purpose if fonts are missing.
import matplotlib.font_manager
if hasattr(matplotlib.font_manager, '_rebuild'):
    matplotlib.font_manager._rebuild()
%matplotlib inline 
import seaborn as sns
from matplotlib.lines import Line2D

In [3]:
# Apply seaborn's "white" style before choosing fonts: seaborn's style
# dictionaries include font.family, so calling sns.set_style() after
# matplotlib.rc('font', ...) resets the family to sans-serif and silently
# discards the serif choice made here.
sns.set_style('white')
matplotlib.rc('font', family='serif', serif=['Palatino'])

# Render all text through LaTeX (axis labels below use markup such as \textit{...});
# pdflatex is the TeX engine used if figures are saved with the pgf backend.
pgf_with_rc_fonts = {"pgf.texsystem": "pdflatex"}
matplotlib.rcParams.update(pgf_with_rc_fonts)
matplotlib.rcParams['text.usetex'] = True

def set_style():
    '''
    Apply the seaborn theme shared by the plotting helpers in this notebook.

    First applies seaborn's theme with a serif font and font_scale=1.4, then
    switches to the "white" axes style, overriding the font family/weight, the
    serif font stack (Times, then Palatino), the axes background and the
    marker edge width.
    '''
    white_serif_overrides = {
        "font.family": "serif",
        "font.weight": "normal",
        "font.serif": ["Times", "Palatino", "serif"],
        "axes.facecolor": "white",
        "lines.markeredgewidth": 1,
    }
    sns.set(font='serif', font_scale=1.4)
    sns.set_style("white", white_serif_overrides)
    
def plot_sig_line(ax, x1, x2, y1, h, padding=0.3, text='*', fontsize=16):
    '''
    Draw a significance bracket between x1 and x2 with a marker centred above it.

    Both legs rise vertically from y1 to y1 + h and are joined by a horizontal
    segment at y1 + h, so a single baseline y coordinate is enough.

    Parameters
    ----------
    ax : matplotlib Axes (any object with .plot and .text works)
    x1, x2 : x positions of the bracket's two legs.
    y1 : y coordinate where both legs start.
    h : height of the legs; also scales the gap between bracket and marker.
    padding : gap between the top of the bracket and the marker, as a fraction of h.
    text : marker drawn above the bracket (default '*').
    fontsize : font size of the marker.
    '''
    y_top = y1 + h
    ax.plot([x1, x1, x2, x2], [y1, y_top, y_top, y1], linewidth=1, color='k')
    # ha='center' centres the marker on the bracket; matplotlib's default
    # horizontal alignment is 'left', which would shift it to the right.
    ax.text(0.5 * (x1 + x2), y_top + padding * h, text, color='k',
            fontsize=fontsize, fontweight='normal', ha='center')

In [None]:
def line_plot(data, save_path=None):
    '''
    Plot mean TPA (with standard-error bars) against the number of preference
    queries for each method, and save the figure as a PDF.

    Parameters
    ----------
    data : dict
        Maps every label in `methods` (below) to an array of TPA values. Axis 0
        must have length 10: entry k is plotted at x = np.linspace(10, 190, 10)[k].
        Mean and standard error are taken over axes 1 and 2.
        NOTE(review): what axes 1 and 2 index (e.g. seeds / users) is not visible
        here -- confirm against the code that builds `data`.
    save_path : str, optional
        Output PDF path. Defaults to
        parent_dir + '/data/figures/TPA_jacorobot_N1000.pdf'.
    '''
    fontsize = 42
    # Apply the seaborn theme first: set_style() calls sns.set(font_scale=1.4),
    # which overwrites font.size and the tick label sizes, and axes take their
    # style (facecolor, spines, grid) from rcParams when they are created.
    set_style()
    matplotlib.rcParams.update({'font.size': fontsize})
    plt.rcParams['xtick.labelsize'] = fontsize
    plt.rcParams['ytick.labelsize'] = fontsize
    fig, ax = plt.subplots(figsize=(20, 10))

    ### TODO: change the method labels and the colors if desired.
    methods = ['SIRL+VAE (Ours)', 'SIRL (Ours)', 'SinglePref', 'MultiPref 10H', 'MultiPref 50H', 'Random', 'VAE']
    colors = ['#f79646', '#f79646', '#6283F7', '#6283F7', '#6283F7', '#D3D3D3', '#6a0dad']
    lines = ['o-', 'o--', 'o-', 'o--', 'o:', 'o-', 'o-']
    # Query budgets on the x-axis, one per entry along axis 0 of each array.
    num_queries = np.linspace(10, 190, 10)

    for method, color, fmt in zip(methods, colors, lines):
        # TODO: our metric was TPA; feel free to change it to whatever name you want
        scores = data[method]
        tpa_mean = np.mean(scores, axis=(1, 2))
        # Standard error of the mean over all axis-1 x axis-2 samples at each budget.
        tpa_sem = np.std(scores, axis=(1, 2)) / np.sqrt(np.prod(scores.shape[1:]))
        ax.errorbar(num_queries, tpa_mean, yerr=tpa_sem, fmt=fmt, color=color, ecolor='black',
                    linewidth=3, markersize=10, elinewidth=2, capsize=5, label=method)

    # TODO: change whatever labels you want here
    ax.set_xlabel('Number of Preference Queries', fontsize=fontsize)
    ax.set_ylabel(r'\textit{TPA}', fontsize=fontsize)
    ax.set_ylim([0.5, 1.0])
    ax.tick_params(axis="both", labelsize=fontsize)
    ax.set_xticks(num_queries)
    ax.set_title(r'JacoRobot with $N=1000$', fontsize=fontsize)

    # matplotlib fills legend columns top to bottom, so with ncol=4 this
    # permutation of the plotting order decides which methods share a row.
    legend_order = [0, 2, 1, 3, 5, 4, 6]
    handles, labels = ax.get_legend_handles_labels()
    handles = [handles[k] for k in legend_order]
    labels = [labels[k] for k in legend_order]
    ax.legend(handles, labels, fontsize=30, ncol=4, mode="expand", frameon=False)
    sns.despine(fig)

    # TODO: use plt.show() to view inside jupyter; the final figure is saved below.
    if save_path is None:
        save_path = parent_dir + '/data/figures/TPA_jacorobot_N1000.pdf'
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(save_path, format='pdf', bbox_inches='tight')

In [None]:
def bar_plot(SIRL, random, save_path=None, show_legend=False):
    '''
    Grouped bar chart of mean TPA (with standard-error bars) for SIRL vs. Random
    over six conditions, saved as a PDF.

    Parameters
    ----------
    SIRL, random : dict
        Map each key in `keys` (below) to the TPA values for that condition.
        NOTE(review): presumably 1-D -- the standard error divides by len(),
        i.e. the length of the first axis; confirm against the caller.
        (`random` shadows the stdlib module name, but only inside this function.)
    save_path : str, optional
        Output PDF path. Defaults to
        parent_dir + '/data/figures/study_TPA_rebuttal.pdf'.
    show_legend : bool, optional
        Draw a "SIRL (Ours)" / "Random" legend. Off by default, matching the
        original figure, which uses the section headings instead.
    '''
    fontsize = 42
    # Apply the seaborn theme before touching rcParams or creating the figure:
    # set_style() calls sns.set(font_scale=1.4), which overwrites font.size and
    # the tick label sizes, and axes take their style from rcParams on creation.
    set_style()
    matplotlib.rcParams.update({'font.size': fontsize})
    plt.rcParams['xtick.labelsize'] = fontsize
    plt.rcParams['ytick.labelsize'] = fontsize
    fig, ax = plt.subplots(figsize=(30, 10))

    SIRL_color = '#f79646'  # SIRL
    Random_color = '#7a7a7a'  # Random
    labels = ["SIRL (Ours)", "Random"]
    barWidth = 0.33
    bar_kw = dict(width=barWidth, ecolor='black', error_kw=dict(elinewidth=2, capsize=5))

    # TODO: you'll have 6 different keys
    keys = ['TPA_real', 'TPA_real6', 'TPA_9real', 'TPA_9real_6', 'TPA_sim', 'TPA_sim6']
    for i, key in enumerate(keys):
        sirl_mean = np.mean(SIRL[key])
        sirl_sem = np.std(SIRL[key]) / np.sqrt(len(SIRL[key]))
        random_mean = np.mean(random[key])
        random_sem = np.std(random[key]) / np.sqrt(len(random[key]))

        # Label only the first group so each method appears once in the legend.
        sirl_label, random_label = labels if i == 0 else (None, None)
        ax.bar(i, sirl_mean, yerr=sirl_sem, color=SIRL_color, label=sirl_label, **bar_kw)
        ax.bar(i + barWidth, random_mean, yerr=random_sem, color=Random_color,
               label=random_label, **bar_kw)

    # TODO: you may not need these two vertical lines
    # Dashed separators between the Real / Held-out / Simulated sections.
    ax.axvline(x=1.65, linestyle='--', color="black")
    ax.axvline(x=3.65, linestyle='--', color="black")

    ax.set_ylim((0, 0.8))
    ax.set_ylabel(r'\textit{TPA}', fontsize=fontsize)
    # Centre each tick label under its SIRL/Random pair of bars.
    ax.set_xticks(np.arange(len(keys)) + barWidth / 2)
    ax.set_xticklabels(['All Humans', '6 Humans', 'All Humans', '6 Humans', 'All Humans', '6 Humans'])
    # Set tick label sizes after the ticks are fixed so the new labels get them too.
    ax.tick_params(axis='both', labelsize=fontsize)

    # TODO: you may not need these headings; pass show_legend=True to use a legend instead
    ax.text(0.5, 0.75, "Real", fontsize=fontsize)
    ax.text(2.3, 0.75, "Held-out", fontsize=fontsize)
    ax.text(4.4, 0.75, "Simulated", fontsize=fontsize)

    if show_legend:
        ax.legend(fontsize=30, frameon=False, ncol=1)
    sns.despine(fig)

    if save_path is None:
        save_path = parent_dir + '/data/figures/study_TPA_rebuttal.pdf'
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(save_path, format='pdf', bbox_inches='tight')