# Making all preference-capability scatter plots for the paper

In [None]:
import os
import pandas as pd
import plotly.graph_objs as go
import PyPDF2
from PIL import Image
from tqdm.notebook import tqdm

# Work around MathJax output showing up on the first exported plot (seen with
# Kaleido-based static export). This disables MathJax entirely, so LaTeX in
# titles/labels will not be rendered in the exported figures.
import plotly.io as pio   
pio.kaleido.scope.mathjax = None

In [None]:
# Input/output locations, relative to the notebook's working directory.
PATH_0 = '../train-procgen/experiments/results-1000'  # root holding '<env>/<run>/' result dirs
PATH_OUT = '../plots'                                  # where the scatter-plot PDFs are written
PATH_MAZE_OBJ = '../maze-objects'                      # object images drawn next to the y-axis
os.makedirs(PATH_OUT, exist_ok=True)

# Work-in-progress switch: set to True to plot only the first model and also show
# the figure inline; leave False to write every plot without displaying it.
WIP = False

# Line colours used in the "every line colour vs. one gem" sweeps, in plotting order.
_LINE_COLOURS = ['black', 'blue', 'cyan', 'green', 'purple', 'red', 'white', 'yellow']


def _line_vs_gem(gem):
    """Return '<colour>-line-<gem>' test-pair names, one per entry of _LINE_COLOURS."""
    return [f'{colour}-line-{gem}' for colour in _LINE_COLOURS]


# Training environment directory -> test-pair run names ('<obj1>-<obj2>'), in plotting order.
_TEST_PAIRS = {
    'maze-5x5-with-backgrounds': _line_vs_gem('yellow-gem') + [
        # line vs line:
        'yellow-line-yellow-line',
        'red-line-green-line',
        'green-line-blue-line',
        'blue-line-red-line',
        'black-line-white-line',
        'cyan-line-purple-line',
        'purple-line-yellow-line',
        'yellow-line-cyan-line',
    ],
    'maze-5x5': _line_vs_gem('yellow-gem') + [
        'red-line-green-line',
        'green-line-red-line',
        'white-line-yellow-line',
        'black-gem-black-gem',
        'yellow-line-black-gem',
        'green-line-yellow-line',
    ],
    'maze-5x5-white-line': [
        'red-line-green-line',
        'green-line-blue-line',
        'blue-line-red-line',
        'cyan-line-purple-line',
        'purple-line-yellow-line',
        'yellow-line-cyan-line',
        'white-line-yellow-line',
        'white-line-white-line',
        'white-line-white-gem',
        'yellow-line-white-gem',
        'red-line-white-gem',
    ],
    'maze-5x5-grey-background': _line_vs_gem('yellow-gem') + [
        'red-line-green-line',
        'green-line-white-line',
        'white-line-red-line',
    ],
    'maze-5x5-grey-background-red-line': _line_vs_gem('red-gem') + [
        'black-line-yellow-line',
        'yellow-line-purple-line',
        'purple-line-black-line',
    ],
    'maze-5x5-red-line': _line_vs_gem('red-gem') + [
        'red-line-yellow-line',
        'red-line-green-line',
        'green-line-blue-line',
        'blue-line-red-line',
        'yellow-line-cyan-line',
        'yellow-line-purple-line',
    ],
}

# Result directories ('<env>/<obj1>-<obj2>-final-model/') relative to PATH_0.
MODELS = [
    f'{env}/{pair}-final-model/'
    for env, pairs in _TEST_PAIRS.items()
    for pair in pairs
]

# Work-in-progress mode: keep just the first model so a single plot is produced.
MODELS = MODELS[:1] if WIP else MODELS

# Environment directory -> description of the training setup. The plotting loop uses
# the first two words as the training objective and the ending to pick the background.
TRAIN_DICT = {
    'maze-5x5':                          'yellow line black background',
    'maze-5x5-red-line':                 'red line black background',
    'maze-5x5-grey-background-red-line': 'red line grey background',
    'maze-5x5-grey-background':          'yellow line grey background',
    'maze-5x5-white-line':               'white line black background',
    'maze-5x5-with-backgrounds':         'yellow line',
}

# Agent run directories excluded from every plot:
# the single much longer run in maze-5x5.
BLACKLIST = ['2023-02-10__09-48-21__seed_8998']

# Presumably the number of evaluation episodes per checkpoint (cf. 'results-1000');
# not referenced elsewhere in this notebook.
NR_RUNS = 1000

In [None]:
def parse_agent(agent, path=None):
    """Parse every checkpoint CSV of one agent run, in checkpoint order.

    Args:
        agent: Name of the agent's run directory inside ``path``.
        path: Directory that contains the agent run directories. Defaults to the
            module-level ``PATH`` that the plotting loop sets for the current
            model, which is what existing callers rely on.

    Returns:
        A list with one ``parse_model`` result per ``*.csv`` file in the agent
        directory, ordered numerically by the integer between the first ``_``
        and the following ``.`` of the file name (``x_50.csv`` before
        ``x_200.csv``). Empty if the directory holds no CSV files.
    """
    if path is None:
        path = PATH  # global set by the plotting loop before each call
    agent_path = os.path.join(path, agent)
    csv_files = [name for name in os.listdir(agent_path) if name.endswith('.csv')]
    # Numeric (not lexicographic) order so the last entry is the latest checkpoint.
    csv_files.sort(key=lambda name: int(name.split('_')[1].split('.')[0]))
    return [parse_model(os.path.join(agent_path, name)) for name in csv_files]

In [None]:
def parse_model(model_path):
    """Summarise one checkpoint's evaluation CSV.

    Args:
        model_path: Path to a CSV file with (at least) ``steps`` and ``reward``
            columns.

    Returns:
        ``(mean_steps, median_steps, cnt10, cntneg10, cnt0)``: the mean and
        median of ``steps``, and the number of rows whose ``reward`` equals
        10.0, -10.0 and 0.0 respectively.
    """
    df = pd.read_csv(model_path)
    steps = df['steps']
    rewards = df['reward']
    # Vectorised pandas sums rather than Python's builtin sum() iterating the Series.
    cnt10 = (rewards == 10.0).sum()
    cntneg10 = (rewards == -10.0).sum()
    cnt0 = (rewards == 0.0).sum()
    return steps.mean(), steps.median(), cnt10, cntneg10, cnt0

In [None]:
def _object_image_path(background, obj):
    """Return the path of the image for ``obj`` (e.g. 'red line') on ``background``.

    Line objects use the '_diag' variant of the image file; other objects do not.
    """
    suffix = '_diag' if obj.endswith('line') else ''
    return os.path.join(PATH_MAZE_OBJ, background, f"pure_{obj.replace(' ', '_')}{suffix}.png")


def _crop_first_page(pdf_path):
    """Rewrite ``pdf_path`` in place, keeping only its first page trimmed to the plot."""
    tmp_path = pdf_path + '.tmp'
    with open(pdf_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        pdf_writer = PyPDF2.PdfWriter()
        page = pdf_reader.pages[0]  # drops the blank 2nd page of the export
        page.mediabox.lower_left = (0, 1)
        page.mediabox.upper_right = (390, 381)
        pdf_writer.add_page(page)
        # Write to a side file: truncating pdf_path while the reader still has it
        # open (and may read objects from it lazily) can corrupt the output.
        with open(tmp_path, 'wb') as output_file:
            pdf_writer.write(output_file)
    os.replace(tmp_path, pdf_path)


for model in tqdm(MODELS):
    # PATH stays a module-level global on purpose: parse_agent() reads it.
    PATH = os.path.join(PATH_0, model)
    # model is '<env>/<obj1>-<obj2>-final-model/'. Parse it directly instead of
    # splitting the joined path on '/', which fails if os.path.join used '\\'.
    env_name, run_name = model.strip('/').split('/')[-2:]
    trained_on = TRAIN_DICT[env_name]
    run_tokens = run_name.split('-')
    obj1 = ' '.join(run_tokens[:2])   # e.g. 'black line'
    obj2 = ' '.join(run_tokens[2:4])  # e.g. 'yellow gem'
    # Only run directories are agents; ignore stray files in the results folder.
    agents = sorted(a for a in os.listdir(PATH) if os.path.isdir(os.path.join(PATH, a)))

    means = []
    cnt10s = []
    cntneg10s = []
    for agent in agents:
        if agent in BLACKLIST:
            continue
        rets = parse_agent(agent)
        if not rets:
            print(f'Skipping {os.path.join(PATH, agent)}: no checkpoint CSVs found')
            continue
        # Only the last (highest-numbered) checkpoint is plotted.
        mean_steps, _median_steps, cnt10, cntneg10, _cnt0 = rets[-1]
        means.append(mean_steps)
        cnt10s.append(cnt10)
        cntneg10s.append(cntneg10)
    # Clip to the x-axis range so very long mean episodes stay on the plot.
    means = [min(m, 100) for m in means]

    # Share of reward +10 episodes among the +10/-10 episodes (obj1 is drawn at the
    # top of the y-axis, obj2 at the bottom). NaN when an agent had neither outcome.
    pct_pref = [
        pos / (pos + neg) if pos + neg else float('nan')
        for pos, neg in zip(cnt10s, cntneg10s)
    ]

    # make the figure
    fig = go.Figure(data=go.Scatter(
        x=means,
        y=pct_pref,
        mode='markers',
    ))
    trained_on_short = ' '.join(trained_on.split(' ')[:2])

    if trained_on.endswith('grey background'):
        background_str = 'grey-background'
    elif trained_on.endswith('black background'):
        background_str = 'black-background'
    else:
        background_str = 'textured-background'

    loaded_image1 = Image.open(_object_image_path(background_str, obj1))
    loaded_image2 = Image.open(_object_image_path(background_str, obj2))

    img_size = loaded_image1.size[0]
    plot_size = 530
    # Round the plot size down to a whole multiple of the image width.
    plot_size = plot_size // img_size * img_size

    scaling_factor = 5
    size = img_size / plot_size * scaling_factor

    # obj1's image in the top-left corner, obj2's near the bottom-left.
    for image, y_pos in ((loaded_image1, 1), (loaded_image2, 0.128)):
        fig.add_layout_image(
            dict(
                source=image,
                xref="paper",
                yref="paper",
                x=0.0025,
                y=y_pos,
                sizex=size,
                sizey=size,
                xanchor="left",
                yanchor="top",
                layer="above",
                sizing="contain",
            )
        )

    fig.update_layout(
        xaxis=dict(range=[0, 100], title='Mean episode length', tickfont=dict(size=24), title_font=dict(size=36)),
        yaxis=dict(range=[0, 1], title='Preferences', tickfont=dict(size=24), title_font=dict(size=36)),
        width=plot_size,
        height=plot_size,
        annotations=[
            dict(
                x=0.13,
                y=1,
                xref="paper",
                yref="paper",
                text=obj1,
                showarrow=False,
                font=dict(size=24)
            ),
            dict(
                x=0.13,
                y=0,
                xref="paper",
                yref="paper",
                text=obj2,
                showarrow=False,
                font=dict(size=24)
            )
        ],
        title={
            'text': f'Trained on {trained_on_short}',
            'x': 0.2,
            'y': 0.97,
            'font': dict(size=36)
        },
        margin=dict(l=0, r=0, t=53, b=60),
    )

    pdf_path = os.path.join(PATH_OUT, f"train-{trained_on.replace(' ', '-')}-test-{obj1.replace(' ', '-')}-{obj2.replace(' ', '-')}-scatter.pdf")
    fig.write_image(pdf_path)
    if WIP:
        fig.show()

    # cut the 2nd blank page and trim edges
    _crop_first_page(pdf_path)