In [None]:
import os 
import pandas as pd 
import plotly.express as px 
import seaborn as sns 
import matplotlib.pyplot as plt


# Base seaborn theme first: set_theme writes its own font settings into
# rcParams, so the paper-specific overrides below must come after it.
sns.set_theme(style="whitegrid")

# LaTeX-rendered serif text at small sizes for figures embedded in a paper.
plt.rcParams.update(
    {
        "font.family": "serif",
        "font.serif": ["Computer Modern Roman"],
        "text.usetex": True,
        "font.size": 9,
        "axes.labelsize": 10,
        "xtick.labelsize": 9,
        "ytick.labelsize": 9,
        "legend.fontsize": 10,
    }
)

# Marker size passed to the line plots, and a font size for facet margin titles.
markersize = 3
margin_title_size = 9

# Figure dimensions in inches (height, and width presumably matching the
# document's text line width).
plot_height = 0.7
line_width = 5.1


from fractions import Fraction
import re


def float_to_latex_fraction(x, limit_denominator=10):
    r"""Render a tick value as a LaTeX fraction string, e.g. ``0.5`` -> ``$\frac{1}{2}$``.

    Args:
        x: Either a plain number, or a matplotlib tick-label ``Text`` object
            (anything with ``get_text()``). For text, the number is taken from
            the first ``{...}`` group when present (as in
            ``$\mathdefault{0.5}$``); otherwise from the text with surrounding
            ``$`` stripped. The Unicode minus sign (U+2212) is normalised to
            ``-`` before parsing.
        limit_denominator: Largest denominator allowed when approximating the
            value via ``Fraction.limit_denominator``.

    Returns:
        ``"$n$"`` for whole numbers, ``"$\frac{n}{d}$"`` otherwise. Negative
        values put the ``-`` outside the fraction (``"$-\frac{1}{2}$"``). A
        label with no number in it (e.g. an empty string) is returned unchanged.

    Raises:
        ValueError: if the extracted text is not a valid float.
    """
    if hasattr(x, "get_text"):
        text = x.get_text()
        match = re.search(r'\{([^}]+)\}', text)
        # Fall back to the bare label when matplotlib did not wrap it in braces.
        raw = match.group(1) if match else text.strip().strip('$')
        raw = raw.replace('−', '-').strip()
        if not raw:
            return text
        value = float(raw)
    else:
        value = float(x)

    frac = Fraction(value).limit_denominator(limit_denominator)
    # Keep the sign outside \frac so "-1/2" renders as -\frac{1}{2}.
    sign = "-" if frac < 0 else ""
    numerator, denominator = abs(frac.numerator), frac.denominator
    if denominator == 1:
        return f"${sign}{numerator}$"
    return f"${sign}\\frac{{{numerator}}}{{{denominator}}}$"


def set_yticklabels_as_fractions(ax):
    """Relabel the y ticks of ``ax`` in place as LaTeX fractions via float_to_latex_fraction.

    The current tick positions are pinned with ``set_yticks`` before the labels
    are replaced. Calling ``set_yticklabels`` alone pairs the new labels with
    whatever ticks the locator happens to produce; matplotlib warns about this
    because the labels can end up on the wrong values if the ticks change later.

    Args:
        ax: matplotlib Axes whose y tick labels are replaced.
    """
    ticks = ax.get_yticks()
    labels = [float_to_latex_fraction(label) for label in ax.get_yticklabels()]
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels)

In [None]:
results = []
for root, dirs, files in os.walk("./results/synthetic/"):
    for file in files:
        if file.endswith(".csv"):
            experiment_name = os.path.basename(root)
            experiment_params = experiment_name.split("-")
            experiment_dict = {param.split("=")[0]: param.split("=")[1] for param in experiment_params}

            metrics_path = os.path.join(root, file)
            metrics_dict = pd.read_csv(metrics_path).to_dict(orient="records")[0]

            result_dict = experiment_dict | metrics_dict
            results.append(result_dict)

In [None]:
# One row per run. Run parameters parsed from directory names arrive as
# strings, so cast them (and the metrics) explicitly, then order the rows by
# training-set size and depth.
column_dtypes = {
    "num_train": int,
    "model_name": str,
    "num_layers": int,
    "seed": int,
    "nlpd": float,
    "mse": float,
}
data = pd.DataFrame(results)
data = data.astype(column_dtypes)
data = data.sort_values(by=["num_train", "num_layers"])

In [None]:
# Human-readable column names; seaborn uses them as axis labels and facet keys.
column_labels = {
    "num_train": "Number of Training Points",
    "num_layers": "Number of Layers",
    "nlpd": "NLPD",
    "mse": "MSE",
    "model_name": "Model",
}
# Legend names for the model identifiers; unmapped identifiers are left as-is.
model_labels = {
    'residual+spherical_harmonic_features': "Residual (IV)",
    'residual+inducing_points': "Residual (PI)",
    'euclidean_with_geometric_input+inducing_points': "Baseline",
    'residual+hodge+spherical_harmonic_features': 'Hodge (IV)',
}

# Relabel, then keep only the runs trained on 800 points.
data = (
    data.rename(columns=column_labels)
    .replace({"Model": model_labels})
    .loc[lambda frame: frame["Number of Training Points"] == 800]
)

In [None]:
# NLPD against depth, one line per model. seaborn's lineplot aggregates repeated
# runs at the same depth (presumably different seeds) into a mean with a
# +/- 1 standard-deviation band.
errorbar = ('sd', 1)

g = sns.FacetGrid(
    data,
    col="Number of Training Points",
    hue="Model",
    margin_titles=True,
    gridspec_kws={"wspace": 0.20, "hspace": 0.0},
)
g.map(sns.lineplot, "Number of Layers", "NLPD", marker="o", errorbar=errorbar, markersize=markersize)
g.add_legend(
    title="",
    loc="upper center",
    bbox_to_anchor=(0.5, 1.75),  # anchored well above the plotting area
    ncol=4,
    fontsize=9,
)

g.set(xticks=[1, 2, 3, 4, 5])  # assumes depths 1-5 were run
# Raw string: "\#" is an invalid Python escape sequence (SyntaxWarning on 3.12+),
# but the backslash itself is required because text.usetex is on and "#" is
# special in LaTeX.
g.set_axis_labels(r"\# Layers")
g.set_titles(col_template="N = {col_name}", size=9)

# Remove all figure margins and shrink to the target size; presumably the tight
# bounding box used when saving recovers room for labels and the legend.
g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
g.figure.set_size_inches(line_width, plot_height)

# set_yticklabels_as_fractions(g.axes[0, 0])
for ax in g.axes.flat:
    # Negative padding pulls tick labels toward the axes to save space.
    ax.tick_params(axis='x', which='major', pad=-3)
    ax.tick_params(axis='y', pad=-4)

# plt.savefig("./plots/synthetic-nlpd_vs_num_layers_and_num_training-all_models-sd1.pdf", bbox_inches='tight')

In [None]:
# MSE against depth, one line per model. seaborn's lineplot aggregates repeated
# runs at the same depth (presumably different seeds) into a mean with a
# +/- 1 standard-deviation band.
errorbar = ('sd', 1)

g = sns.FacetGrid(
    data,
    col="Number of Training Points",
    hue="Model",
    margin_titles=True,
    gridspec_kws={"wspace": 0.19, "hspace": 0.0},
)
g.map(sns.lineplot, "Number of Layers", "MSE", marker="o", errorbar=errorbar, markersize=markersize)
g.add_legend(
    title="",
    loc="upper center",
    bbox_to_anchor=(0.5, 1.75),  # anchored well above the plotting area
    ncol=4,
    fontsize=9,
)

g.set(xticks=[1, 2, 3, 4, 5])  # assumes depths 1-5 were run
# Raw string: "\#" is an invalid Python escape sequence (SyntaxWarning on 3.12+),
# but the backslash itself is required because text.usetex is on and "#" is
# special in LaTeX.
g.set_axis_labels(r"\# Layers")
g.set_titles(col_template="N = {col_name}", size=9)
# MSE is non-negative: start the axis at 0. Fixed ticks assume values stay
# roughly within [0, 0.005].
g.set(ylim=(0, None))
g.set(yticks=[0.0000, 0.0025, 0.0050])

# Remove all figure margins and shrink to the target size; presumably the tight
# bounding box used when saving recovers room for labels and the legend.
g.figure.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0)
g.figure.set_size_inches(line_width, plot_height)

# set_yticklabels_as_fractions(g.axes[0, 0])
for ax in g.axes.flat:
    # Negative padding pulls tick labels toward the axes to save space.
    ax.tick_params(axis='x', which='major', pad=-3)
    ax.tick_params(axis='y', pad=-4)

# plt.savefig("./plots/synthetic-mse_vs_num_layers_and_num_training-all_models-sd1.pdf", bbox_inches='tight')