In [9]:
import pandas as pd
import plotly.express as px
import os

In [2]:
# Root folder that holds one sub-directory per dataset (integration set).
root_dir = "./results/"
# Location of the metrics table inside each dataset directory.
metrics_relpath = "metrics/100.csv"

# os.listdir returns entries in arbitrary order and also lists stray entries
# (e.g. ".DS_Store", ".ipynb_checkpoints"). Sort for a reproducible row order
# and keep only the entries that actually contain the metrics file, so the
# loading cell does not fail on them.
dataset_names = sorted(
    entry
    for entry in os.listdir(root_dir)
    if os.path.isfile(os.path.join(root_dir, entry, metrics_relpath))
)
# Parallel to dataset_names: metrics_fpaths[i] belongs to dataset_names[i].
metrics_fpaths = [os.path.join(root_dir, dataset_dir, metrics_relpath) for dataset_dir in dataset_names]


In [12]:
RECALL_COLUMN = "recall_at_sizeof_ground_truth"
SIMCSE_SUFFIX = "-simcse"


def canonical_method_name(raw_name):
    """Map a raw row label from a metrics CSV to its display label.

    Matching is done on the (case-sensitive) prefix of ``raw_name``:

    * ``valentwin...`` -> ``"ValenTwin"`` (never suffixed)
    * ``alite...``     -> ``"ALITE"``
    * ``starmie...``   -> ``"Starmie-FT"`` if the label contains ``"ft"``,
      otherwise ``"Starmie-PT"``
    * ``deepjoin...``  -> ``"DeepJoin"``

    For the last three families ``"-simcse"`` is appended when the raw label
    contains ``"simcse"``. Labels that match no known prefix are returned
    unchanged.
    """
    if raw_name.startswith("valentwin"):
        return "ValenTwin"
    if raw_name.startswith("alite"):
        base = "ALITE"
    elif raw_name.startswith("starmie"):
        base = "Starmie-FT" if "ft" in raw_name else "Starmie-PT"
    elif raw_name.startswith("deepjoin"):
        base = "DeepJoin"
    else:
        return raw_name
    return base + SIMCSE_SUFFIX if "simcse" in raw_name else base


def plm_group(label):
    """Return ``"SimCSE"`` for SimCSE-suffixed labels and ValenTwin, else ``"Non-SimCSE"``."""
    return "SimCSE" if "simcse" in label or "ValenTwin" in label else "Non-SimCSE"


# One tidy frame per dataset, keyed by dataset (directory) name.
dfs = {}
for dataset_name, fpath in zip(dataset_names, metrics_fpaths):
    # First CSV column holds the raw method labels; normalise them in one pass.
    df = pd.read_csv(fpath, index_col=0).rename(index=canonical_method_name)
    # .assign works on a fresh copy, so adding columns to the one-column
    # selection does not raise SettingWithCopyWarning.
    df = df[[RECALL_COLUMN]].assign(
        integration_set=dataset_name,
        # Literal (non-regex) removal of the suffix gives the method family.
        method=lambda frame: frame.index.str.replace(SIMCSE_SUFFIX, "", regex=False),
        plm=lambda frame: frame.index.map(plm_group),
    )
    dfs[dataset_name] = df

# Stack the per-dataset frames on top of each other (row-wise, axis=0).
concatenated_df = pd.concat(list(dfs.values()), axis=0)
concatenated_df


Unnamed: 0,recall_at_sizeof_ground_truth,integration_set,method,plm
ALITE,0.268775,academic_papers,ALITE,Non-SimCSE
ALITE-simcse,0.501976,academic_papers,ALITE,SimCSE
DeepJoin,0.521739,academic_papers,DeepJoin,Non-SimCSE
DeepJoin-simcse,0.652174,academic_papers,DeepJoin,SimCSE
Starmie-FT,0.652174,academic_papers,Starmie-FT,Non-SimCSE
...,...,...,...,...
Starmie-FT,0.820513,cihr,Starmie-FT,Non-SimCSE
Starmie-FT-simcse,0.815705,cihr,Starmie-FT,SimCSE
Starmie-PT-simcse,0.842949,cihr,Starmie-PT,SimCSE
Starmie-PT,0.822115,cihr,Starmie-PT,Non-SimCSE


In [13]:
# Box plot of the recall column per method, one box colour per PLM group.
display_labels = {
    "method": "Method",
    "plm": "PLM",
    "recall_at_sizeof_ground_truth": "Recall@k",
}
fig = px.box(
    concatenated_df,
    x="method",
    y="recall_at_sizeof_ground_truth",
    color="plm",
    labels=display_labels,
    height=400,
    width=600,
)

# Also show the mean in every box.
fig.update_traces(boxmean=True)

# Horizontal legend placed just above the plotting area.
legend_style = dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=0.75)
fig.update_layout(
    # NOTE(review): "Times Roman" may not resolve as a font family on every
    # system (the common family name is "Times New Roman") - confirm rendering.
    font_family="Times Roman",
    font_color="black",
    font_size=16,
    plot_bgcolor="white",
    legend=legend_style,
    margin=dict(l=0, r=0, t=0, b=0),
)

# Frame settings shared by both axes: mirrored black axis lines, outside
# ticks and light grey grid lines.
axis_frame = dict(
    mirror=True,
    ticks="outside",
    showline=True,
    linecolor="black",
    gridcolor="lightgrey",
)
# Fixed left-to-right order of the method categories on the x axis.
method_order = ["ALITE", "Starmie-PT", "Starmie-FT", "DeepJoin", "ValenTwin"]
fig.update_xaxes(categoryorder="array", categoryarray=method_order, **axis_frame)
fig.update_yaxes(tickformat=".2f", **axis_frame)

fig.show()
