In [None]:
import os 
import pandas as pd 
import plotly.express as px 
import seaborn as sns 
import matplotlib.pyplot as plt
import numpy as np 


# seaborn's set_theme() changes matplotlib's global rcParams, so it runs first
# and the overrides below are applied on top of it.
sns.set_theme(style="whitegrid")

# Typeset all figure text with LaTeX in a Computer Modern serif, uniformly 8 pt.
_FONT_PT = 8
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"],
    "text.usetex": True,  # requires a working LaTeX installation
    "font.size": _FONT_PT,
    "axes.labelsize": _FONT_PT,
    "xtick.labelsize": _FONT_PT,
    "ytick.labelsize": _FONT_PT,
    "legend.fontsize": _FONT_PT,
})

markersize = 3                  # marker size passed to sns.lineplot
margin_title_size = _FONT_PT    # font size of the facet titles

# Figure size in inches (used with Figure.set_size_inches(width, height));
# line_width is presumably the target document's \linewidth -- confirm.
plot_height = 0.7
line_width = 5.1

In [None]:
# Gather every per-run timing CSV below RESULTS_DIR into a list of DataFrames.
# They are concatenated (once) and typed in the next cell.
RESULTS_DIR = "results/time_uci"

results = []
# os.walk / directory listings come back in arbitrary order; sort both the
# walk and the file names so frames are collected identically on every machine.
for root, _subdirs, files in sorted(os.walk(RESULTS_DIR)):
    for file in sorted(files):
        if not file.endswith(".csv"):
            continue
        # Progress log only: the run's directory name, with "=-1" shown as "=None".
        experiment_name = os.path.basename(root).replace("=-1", "=None")
        print(experiment_name)
        results.append(pd.read_csv(os.path.join(root, file)))

# Fail loudly here rather than with pandas' "No objects to concatenate" later.
if not results:
    raise FileNotFoundError(f"no .csv result files found under {RESULTS_DIR!r}")

In [None]:
# Only runs trained for this many iterations are compared in the figure.
NUM_ITERS = 100

# Combine the per-run frames, enforce column dtypes, sort rows deterministically
# (model first), and keep only the runs with NUM_ITERS iterations.
data = (
    # ignore_index: give the combined frame a fresh 0..n-1 index instead of
    # repeating every per-file index label.
    pd.concat(results, ignore_index=True)
    .astype(
        {
            "dataset": str,
            "model": str,
            "num_layers": int,
            "dataset_dim": int,
            "batch_size": int,
            "num_iters": int,
            "num_inducing": int,
            "seed": int,
            "time": float,
        }
    )
    .sort_values(by=["model", "num_layers", "seed", "batch_size", "dataset"])
    .query("num_iters == @NUM_ITERS")
)

assert not data.empty, f"no result rows with num_iters == {NUM_ITERS}"

In [None]:
# Error bars: +/- 1 standard deviation of the repeated runs at each point
# (presumably the different seeds -- confirm no other column varies).
errorbar = ("sd", 1)

# Set SAVE_FIG = True to write the figure to FIG_PATH at the end of this cell.
SAVE_FIG = False
FIG_PATH = "./plots/uci-timing.pdf"

# Human-readable column names and model labels for the figure.
plot_data = (
    data.rename(columns={
        "dataset": "Dataset",
        "model": "Model",
        "num_layers": "Number of Layers",
        "time": "Time (s)",
    })
    .replace(
        {
            "Model": {
                "residual+spherical_harmonic_features": "Residual (IV)",
                "euclidean+inducing_points": "Euclidean (IL)",
            },
        }
    )
    # One facet per (dataset, batch size, input dim); the facet title shows all three.
    .assign(
        Dataset=lambda x: (
            x["Dataset"].str.capitalize()
            + "\nB=" + x["batch_size"].astype(str)
            + ", d=" + x["dataset_dim"].astype(str)
        ),
    )
)

g = sns.FacetGrid(
    plot_data,
    col="Dataset",
    margin_titles=True,
    sharey=False,
    hue="Model",
    gridspec_kws={"wspace": 0.25, "hspace": 0.0},
)
g.map(sns.lineplot, "Number of Layers", "Time (s)", marker="o", errorbar=errorbar, markersize=markersize)

# The y-anchor above 1 places the legend above the very short axes; the value
# appears hand-tuned for this figure size.
g.add_legend(
    title="",
    loc="upper center",
    bbox_to_anchor=(0.5, 1.85),
    ncol=2,
)

# One x tick per layer count actually present in the data.
g.set(xticks=sorted(plot_data["Number of Layers"].unique()))
# Raw string: "\#" is LaTeX's escaped '#' (text.usetex is on); in a plain
# Python string it is an invalid escape sequence and triggers a warning.
g.set_axis_labels(r"\# Layers")
g.set_titles(col_template="{col_name}", size=margin_title_size)

for ax in g.axes.flat:
    # Negative padding pulls tick labels toward the axes to save space.
    ax.tick_params(axis="both", which="major", pad=-3)
    ax.tick_params(axis="y", pad=-4)

g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
g.figure.set_size_inches(line_width, plot_height)

if SAVE_FIG:
    os.makedirs(os.path.dirname(FIG_PATH), exist_ok=True)
    # bbox_inches="tight" keeps the legend that sits outside the axes.
    g.figure.savefig(FIG_PATH, bbox_inches="tight", pad_inches=0.0)