# Producing the OLS plot

In [None]:
import os
import pandas as pd
import plotly.graph_objects as go
import statsmodels.formula.api as smf
import plotly.express as px
import PyPDF2
from PIL import Image

# fix mathjax showing up on plot bug (disables mathjax though):
# Work around MathJax-related output showing up on the statically exported
# (kaleido) figure. Setting the kaleido scope's `mathjax` to None disables
# MathJax entirely, so LaTeX in titles/labels will not be rendered.
# NOTE(review): `pio.kaleido.scope` belongs to the older plotly/kaleido export
# API; confirm it still exists for the installed plotly/kaleido versions.
import plotly.io as pio   
pio.kaleido.scope.mathjax = None

In [None]:
# Input/output locations, relative to the notebook's working directory.
PATH_0 = '../train-procgen/experiments/'
PATH_MAZE_OBJ = '../maze-objects'  # sprites that are drawn onto the plot
PATH_OUT = '../plots/'
os.makedirs(PATH_OUT, exist_ok=True)

# One results directory per object pair. Each holds one sub-directory per agent,
# which in turn holds CSV files named like '<prefix>_<N>.csv' (see parse_agent).
# The trailing '/' matters: the name parsing below and in later cells indexes
# path.split('/') from the end, where the last element is then ''.
PATHS = [
    os.path.join(PATH_0, 'results-1000/maze-5x5/green-line-red-line-final-model/'),
    os.path.join(PATH_0, 'results-1000/maze-5x5/yellow-gem-red-line-final-model/'),
]

# Training-environment directory name -> human-readable description used in the
# output file name.
TRAIN_DICT = {
    'maze-5x5': 'yellow line black background',
    'maze-5x5-red-line': 'red line black background',
    'maze-5x5-grey-background-red-line': 'red line grey background',
    'maze-5x5-grey-background': 'yellow line grey background',
    'maze-5x5-white-line': 'white line black background',
    'maze-5x5-with-backgrounds': 'yellow line',
             }

PATH = PATHS[0]
# split('/')[-2] is the object-pair directory (e.g. 'green-line-red-line-final-model'),
# split('/')[-3] the training-environment directory (e.g. 'maze-5x5').
trained_on = TRAIN_DICT[PATH.split('/')[-3]]
# Space-separated object names from the first path, e.g. 'green line' / 'red line'.
# NOTE(review): later cells reassign obj1/obj2 (dash-separated) before any use.
obj1 = ' '.join(PATH.split('/')[-2].split('-')[:2])
obj2 = ' '.join(PATH.split('/')[-2].split('-')[2:4])

BLACKLIST = ['2023-02-10__09-48-21__seed_8998']  # the single much longer run in maze-5x5
# NOTE(review): NR_RUNS is not referenced anywhere in this notebook section.
NR_RUNS = 1000

In [None]:
def parse_agent(path, agent, max_models=10):
    """Return the evaluation statistics of one agent's checkpoint CSVs.

    Args:
        path: Results directory containing one sub-directory per agent.
        agent: Name of the agent's sub-directory inside ``path``.
        max_models: Maximum number of CSV files to parse, taken in ascending
            numeric order. Defaults to 10 (the previously hard-coded cap);
            pass ``None`` to parse every CSV file.

    Returns:
        A list with one ``parse_model`` tuple per parsed CSV, in ascending
        numeric order. ``rets[-1]`` is therefore the last *parsed* file, which
        is not the highest-numbered one when more than ``max_models`` exist.
    """
    agent_path = os.path.join(path, agent)
    # File names are expected to look like '<prefix>_<N>.csv'; sort numerically
    # on N so that e.g. 'model_10.csv' comes after 'model_9.csv'.
    models = sorted(
        (name for name in os.listdir(agent_path) if name.endswith('.csv')),
        key=lambda name: int(name.split('_')[1].split('.')[0]),
    )
    return [parse_model(os.path.join(agent_path, model)) for model in models[:max_models]]

In [None]:
def parse_model(model_path):
    """Summarise one evaluation CSV file.

    The CSV must contain a ``steps`` column and a ``reward`` column.

    Returns:
        ``(mean_steps, median_steps, cnt10, cntneg10, cnt0)``: the mean and
        median of ``steps``, and the number of rows whose ``reward`` is exactly
        10.0, -10.0 and 0.0 respectively. The counts are plain Python ints.
    """
    df = pd.read_csv(model_path)
    steps = df['steps']
    rewards = df['reward']
    # Vectorised counting instead of Python-level sum() over a boolean Series;
    # int() keeps the counts as plain Python ints, as the builtin sum() produced.
    cnt10 = int((rewards == 10.0).sum())
    cntneg10 = int((rewards == -10.0).sum())
    cnt0 = int((rewards == 0.0).sum())
    return steps.mean(), steps.median(), cnt10, cntneg10, cnt0

In [None]:
# Diagnostic pass: for every results directory, print the statistics of each
# non-blacklisted agent's last parsed CSV followed by its preference share,
# with a blank line between directories.
for path in PATHS:
    agents = sorted(os.listdir(path))

    means = []
    medians = []
    cnt10s = []
    cntneg10s = []
    cnt0s = []
    pref_pcts = []
    for agent in agents:
        if agent in BLACKLIST:
            continue
        # parse_agent returns one stats tuple per parsed CSV; use the last one.
        rets = parse_agent(path, agent)
        mean, median, cnt10, cntneg10, cnt0 = rets[-1]
        means.append(mean)
        medians.append(median)
        cnt10s.append(cnt10)
        cntneg10s.append(cntneg10)
        cnt0s.append(cnt0)
        # Fraction of +10 rows among rows with reward +10 or -10. NaN instead of
        # ZeroDivisionError when the agent produced neither outcome.
        decided = cnt10 + cntneg10
        pref_pct = cnt10 / decided if decided else float('nan')
        pref_pcts.append(pref_pct)
        print(*rets[-1], agent, pref_pct)
    print()

In [None]:
# Build one DataFrame indexed by agent name holding, for every results
# directory, the mean/median episode length and the preference share of each
# agent's last parsed CSV. `prefs` collects the preference column names.
frames = []
prefs = []
for path in PATHS:
    # Object-pair directory, e.g. 'green-line-red-line-final-model' -> 'green-line', 'red-line'.
    pair = path.split('/')[-2].split('-')
    obj1 = '-'.join(pair[:2])
    obj2 = '-'.join(pair[2:4])
    col1 = f'{obj1}-{obj2}-mean-ep-len'
    col2 = f'{obj1}-{obj2}-median-ep-len'
    col3 = f'{obj1}-{obj2}-pref'
    prefs.append(col3)

    # Skip blacklisted runs, consistent with the diagnostic cell; the index
    # (agent_names) is derived from the same filtered list so it stays aligned.
    agents = [agent for agent in sorted(os.listdir(path)) if agent not in BLACKLIST]
    agent_names = [f'agent-{agent.split("_")[-1]}' for agent in agents]

    means = []
    medians = []
    cnt10s = []
    cntneg10s = []
    cnt0s = []
    pref_pcts = []
    for agent in agents:
        rets = parse_agent(path, agent)
        mean, median, cnt10, cntneg10, cnt0 = rets[-1]
        means.append(mean)
        medians.append(median)
        cnt10s.append(cnt10)
        cntneg10s.append(cntneg10)
        cnt0s.append(cnt0)
        # NaN (dropped by the OLS fit) instead of ZeroDivisionError when the
        # agent produced neither a +10 nor a -10 reward.
        decided = cnt10 + cntneg10
        pref_pct = cnt10 / decided if decided else float('nan')
        pref_pcts.append(pref_pct)

    temp_df = pd.DataFrame({col1: means, col2: medians, col3: pref_pcts}, index=agent_names)
    frames.append(temp_df)

# Columns from different directories are aligned on the agent name (outer join).
df = pd.concat(frames, axis=1) if frames else pd.DataFrame()

print(df.head())

In [None]:
# Regress the preference column of PATHS[1] (y-axis) on the preference column
# of PATHS[0] (x-axis), one point per agent, and save the scatter plot with its
# OLS line as a PDF.
pref1 = prefs[1]  # response column (y-axis)
pref2 = prefs[0]  # predictor column (x-axis)

# Fit the OLS regression model. Q("...") quotes the column names, which contain
# dashes and would otherwise be parsed as subtraction by the formula language.
model = smf.ols(f'Q("{pref1}") ~ Q("{pref2}")', data=df)
result = model.fit()

# Readable titles built from '<a>-<b>-<c>-<d>-pref' column names.
# NOTE(review): not used below; both axes are titled 'Preferences'.
words = pref2.split("-")
x_title = f"Preference for {words[0]} {words[1]} vs {words[2]} {words[3]}"
words = pref1.split("-")
y_title = f"Preference for {words[0]} {words[1]} vs {words[2]} {words[3]}"

# Create a scatter plot
fig = px.scatter(df, x=pref2, y=pref1)

# Add the OLS regression line, evaluated at the observed x extremes.
x_range = [df[pref2].min(), df[pref2].max()]
y_range = result.predict(pd.DataFrame({pref2: x_range})).tolist()
# params are ordered [intercept, slope]; .iloc gives positional access (integer
# keys on a label-indexed Series are deprecated in pandas).
intercept = result.params.iloc[0]
slope = result.params.iloc[1]
line = go.Scatter(
    x=x_range,
    y=y_range,
    mode='lines',
    name=f"OLS: R^2={result.rsquared:.3f},<br>coef0={intercept:.3f}, coef1={slope:.3f}",
)
fig.add_trace(line)

# manual adjustment needed for image names:
loaded_image1 = Image.open(os.path.join(PATH_MAZE_OBJ, 'black-background', 'pure_yellow_gem.png'))
loaded_image2 = Image.open(os.path.join(PATH_MAZE_OBJ, 'black-background', 'pure_red_line_diag.png'))
loaded_image3 = Image.open(os.path.join(PATH_MAZE_OBJ, 'black-background', 'pure_green_line_diag.png'))

# Round the plot size down to a multiple of the sprite width (presumably so the
# sprites scale by a whole number of pixels).
img_size = loaded_image1.size[0]
plot_size = 530
plot_size = plot_size // img_size * img_size

# Sprite size as a fraction of the plot area (paper coordinates).
scaling_factor = 5
size = img_size / plot_size * scaling_factor

# Legend in the top-right corner: anchored by its right/top edge at x=1, y=1.
fig.update_layout(
    legend=dict(
        x=1,
        y=1,
        traceorder="normal",
        font=dict(
            family="sans-serif",
            size=18,
            color="black"
        ),
        bgcolor="White",
        bordercolor="Black",
        borderwidth=1,
        xanchor="right",
        yanchor="top"
    ),
    xaxis=dict(range=[0, 1], title='Preferences', tickfont=dict(size=24), title_font=dict(size=36)),
    yaxis=dict(range=[0, 1], title='Preferences', tickfont=dict(size=24), title_font=dict(size=36)),
    width=plot_size,
    height=plot_size,
)

# Object sprites at hand-tuned paper positions (top edge at y_pos, left edge at
# x_pos): image 1 near the top-left, images 2 and 3 near the bottom at the left
# and right edges. Added in this order, as before.
for image, x_pos, y_pos in (
    (loaded_image1, 0.0025, 1),
    (loaded_image2, 0.0025, 0.12),
    (loaded_image3, 0.8745, 0.12),
):
    fig.add_layout_image(
        dict(
            source=image,
            xref="paper",
            yref="paper",
            x=x_pos,
            y=y_pos,
            sizex=size,
            sizey=size,
            xanchor="left",  # image extends rightwards from x_pos
            yanchor="top",  # image extends downwards from y_pos
            layer="above",
            sizing="contain",
        )
    )

# Save and show the plot
pdf_path = os.path.join(PATH_OUT, f"train-{trained_on.replace(' ', '-')}-regress-{pref1}-{pref2}-scatter.pdf")
fig.write_image(pdf_path)
fig.show()


In [None]:
# cut the 2nd blank page and trim edges
with open(pdf_path, 'rb') as pdf_file:
    pdf_reader = PyPDF2.PdfReader(pdf_file)
    pdf_writer = PyPDF2.PdfWriter()
    page = pdf_reader.pages[0]
    page.mediabox.lower_left = (0, 8)
    page.mediabox.upper_right = (390, 360)
    pdf_writer.add_page(page)

    with open(pdf_path, 'wb') as output_file:
        pdf_writer.write(output_file)