In [None]:
import os 
import pandas as pd 
import plotly.express as px 
import seaborn as sns 
import matplotlib.pyplot as plt


sns.set_theme(style="whitegrid")

# set_theme writes rcParams itself, so it runs first and the explicit settings
# below take precedence. All text is rendered through LaTeX (text.usetex) in a
# Computer Modern serif font; update() applies the keys in the listed order.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Computer Modern Roman'],
    'text.usetex': True,
    'font.size': 9,
    'axes.labelsize': 10,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 10,
})

# Marker size for the line plots and font size for the facet titles.
markersize, margin_title_size = 3, 9

# Figure geometry in inches: facet-row height and overall width
# (presumably the target document's line width -- confirm).
plot_height, line_width = 0.7, 5.1


from fractions import Fraction
import re


def float_to_latex_fraction(x, limit_denominator=10):
    r"""Format a tick value as a LaTeX fraction (or integer) string.

    Parameters
    ----------
    x : object with ``get_text()``, str, or real number
        A tick label such as those returned by ``ax.get_yticklabels()``, the
        label's text, or the numeric value itself. Label text may wrap the
        number in braces (e.g. ``$\mathdefault{0.5}$``) and may use the Unicode
        minus sign U+2212. Plain text such as ``"0.5"`` or ``"$0.5$"`` is also
        accepted.
    limit_denominator : int, default 10
        Largest denominator allowed when approximating the value.

    Returns
    -------
    str
        ``"$n$"`` when the approximation is an integer, otherwise
        ``"$\frac{n}{d}$"``. Negative values get a leading ``-`` outside the
        fraction (``"$-\frac{1}{2}$"``).

    Raises
    ------
    ValueError
        If no number can be parsed from the label text.
    """
    if hasattr(x, "get_text"):
        x = x.get_text()
    if isinstance(x, str):
        label = x
        # Use the number inside the first {...} if present, else the bare text.
        match = re.search(r'\{([^}]+)\}', label)
        number_text = match.group(1) if match else label.strip().strip('$')
        # Tick labels may use the Unicode minus sign U+2212, which float() rejects.
        number_text = number_text.replace('\u2212', '-')
        try:
            value = float(number_text)
        except ValueError:
            raise ValueError(f"cannot parse a number from tick label {label!r}") from None
    else:
        value = x
    frac = Fraction(value).limit_denominator(limit_denominator)
    # Typeset the sign outside the fraction rather than inside the numerator.
    sign = '-' if frac < 0 else ''
    frac = abs(frac)
    if frac.denominator == 1:
        return f"${sign}{frac.numerator}$"
    return f"${sign}\\frac{{{frac.numerator}}}{{{frac.denominator}}}$"


def set_yticklabels_as_fractions(ax):
    """Relabel the y ticks of ``ax`` as LaTeX fractions.

    The tick positions currently produced by the y locator, restricted to the
    visible y-range, are pinned with ``set_yticks`` before the labels are
    replaced. ``set_yticklabels`` installs a FixedFormatter that assigns labels
    by tick *index*. Without pinned positions, a later change in axis size can
    make the automatic locator pick different ticks, and the labels would land
    on the wrong values. Call this after the figure has its final size so the
    pinned ticks suit the final layout.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose y ticks and tick labels are rewritten in place.
    """
    lo, hi = sorted(ax.get_ylim())
    # Small tolerance so a tick sitting exactly on a limit survives float round-off.
    tol = 1e-9 * (hi - lo)
    # Out-of-view ticks are dropped: set_yticks would otherwise widen the view limits.
    visible_ticks = [t for t in ax.get_yticks() if lo - tol <= t <= hi + tol]
    ax.set_yticks(visible_ticks)
    ax.set_yticklabels([float_to_latex_fraction(t) for t in ax.get_yticklabels()])

In [None]:
results = []
for root, dirs, files in os.walk("results/uci/"):
    for file in files:
        if file.endswith(".csv"):
            experiment_name = os.path.basename(root)
            experiment_name = experiment_name.replace("=-1", "=None")
            print(experiment_name)
            experiment_params = experiment_name.split("-")
            experiment_dict = {param.split("=")[0]: param.split("=")[1] for param in experiment_params}

            metrics_path = os.path.join(root, file)
            metrics_dict = pd.read_csv(metrics_path).to_dict(orient="records")[0]

            result_dict = experiment_dict | metrics_dict
            results.append(result_dict)

In [None]:
# Build the analysis frame from the per-run records in `results`:
#   * cast the identifying columns and the metrics to proper dtypes,
#   * drop records without a `num_iters` value before casting it to int,
#   * map the "None" placeholder (the loader writes it for `=-1`) back to -1,
#   * drop the "Unnamed: 0" column (presumably an index written by to_csv -- confirm),
#   * keep only the runs shown in the figure.
# The steps are order-dependent: the query strings below rely on the dtypes
# each column has at that point in the chain.
data = pd.DataFrame(results).astype(
    {
        'dataset_name': str, 
        "model_name": str, 
        "num_layers": int, 
        "seed": int, 
        # "num_iters": int, 
        "mse": float, 
        "nlpd": float,
    }
    # num_iters is cast only after rows where it is missing have been dropped.
).dropna(subset=["num_iters"]).astype(
    {
        "num_iters": int
    }
).sort_values(
    by=['model_name', 'num_layers', 'seed']
).replace(
    "None", -1
).drop(
    columns=["Unnamed: 0"]
    # Keep 5000-iteration runs, seeds 0-4, and kernel_max_ell == -1 (i.e. "None").
    # For kin8mn and power, keep only batch_size 1000. batch_size is not cast
    # above, so it is compared against the string '1000'.
).query(
    "num_iters == 5000 & seed in [0, 1, 2, 3, 4] & kernel_max_ell == -1 & (dataset_name not in ['kin8mn', 'power'] | batch_size == '1000')"
).drop_duplicates().query((
    # Restrict to the two models and five datasets plotted below. The adjacent
    # string literals concatenate into a single query expression.
    "model_name in ['euclidean+inducing_points', 'residual+spherical_harmonic_features']"
    "& dataset_name in ['yacht', 'energy', 'concrete', 'kin8mn', 'power']"
))

# Per-dataset values shown in the facet titles as "B=..." (batch size).
# NOTE(review): the assign below overwrites the batch_size column from the runs
# with these values; presumably the effective batch size per dataset -- confirm.
# NOTE(review): 'kin8mn' must match the result data; the benchmark is usually
# called kin8nm, so confirm against the result directories before renaming.
dataset_to_batch_size = {
    "yacht": 277, 
    "energy": 691, 
    "concrete": 927, 
    "kin8mn": 1000, 
    "power": 1000, 
}

# Per-dataset values shown in the facet titles as "D=..."
# (presumably the input dimensionality -- TODO confirm).
dataset_to_dimension = {
    "yacht": 6, 
    "energy": 8, 
    "concrete": 8, 
    "kin8mn": 8, 
    "power": 4, 
}

# Attach the metadata columns (replacing the existing batch_size column) and
# order rows by batch size, then dataset name.
data = data.assign(
    batch_size = lambda x: x["dataset_name"].map(dataset_to_batch_size),
    dimension = lambda x: x["dataset_name"].map(dataset_to_dimension),
).sort_values(
    by=['batch_size', 'dataset_name']
)

In [None]:
# NLPD vs. number of layers, one facet per dataset, comparing the
# Euclidean (PI) and Residual (IV) models; error bars are +/- 1 standard deviation.
errorbar = ('sd', 1)

plot_data = data.rename(columns={
    "dataset_name": "Dataset",
    "model_name": "Model",
    "num_layers": "Number of Layers",
    "nlpd": "NLPD",
    "mse": "MSE"
}).replace(
    {
        "Model": {
            "residual+spherical_harmonic_features": "Residual (IV)",
            "euclidean+inducing_points": "Euclidean (PI)",
        },
    }
).assign(
    # The data uses the key 'kin8mn'; the benchmark is conventionally named
    # kin8nm, so correct the spelling in the displayed label only.
    Dataset=lambda x: x["Dataset"].replace("kin8mn", "kin8nm").str.capitalize()
).assign(
    Dataset=lambda x: x["Dataset"] + "\nB=" + x["batch_size"].astype(str) + ", D=" + x["dimension"].astype(str),
)

g = sns.FacetGrid(
    plot_data, 
    col="Dataset", 
    margin_titles=False, 
    sharey=False, 
    hue="Model",
    gridspec_kws={"wspace":0.25, "hspace": 0.0}
)

g.map(sns.lineplot, "Number of Layers", "NLPD", marker="o", errorbar=errorbar, markersize=markersize)
g.add_legend(
    title="",
    loc="upper center",
    bbox_to_anchor=(0.5, 1.95),
    ncol=2,
)

g.set(xticks=[1, 2, 3, 4, 5])
# Raw string: `\#` escapes '#' for LaTeX (text.usetex) and is not a valid Python escape.
g.set_axis_labels(r"\# Layers")
g.set_titles(col_template="{col_name}", size=margin_title_size)

# Fix the final layout and size *before* relabelling the y ticks. The automatic
# tick locator picks positions based on the axis length, and the fraction labels
# are attached to whatever tick positions exist when they are set.
g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
g.figure.set_size_inches(line_width, plot_height)

for ax in g.axes.flat:
    set_yticklabels_as_fractions(ax)
    # Negative padding pulls tick labels closer to the axes to save space.
    ax.tick_params(axis='both', which='major', pad=-3)
    ax.tick_params(axis='y', pad=-4)

# To export:
# plt.savefig("./plots/uci-nlpd_vs_num_layers-euclidean_and_residual-sd1-size_optimised.pdf", bbox_inches='tight', pad_inches=0.0)