In [None]:
import math
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Root directory of the training logs; each experiment's metrics are read from
# `<logs_base_path>/results/<experiment_name>.json`.
# NOTE(review): an earlier value was "../../logs/" (presumably for running the
# notebook from its own directory) — switch back if the kernel's cwd differs.
logs_base_path = "logs/"
# Experiments to compare side by side; each one gets its own figure column.
experiment_names = ["mnist-gd-grokking", "mnist-grokking-slower"]

In [None]:
# One column per experiment, three rows of metrics sharing the step axis:
#   top    — Transformer train/test accuracy plus SVM and GP accuracy,
#   middle — kernel alignment,
#   bottom — DOTS.

# Step spacing assumed between logged rows when a results file has no "step"
# column. NOTE(review): confirm this matches the logging interval used in training.
STEP_INTERVAL = 5

num_experiments = len(experiment_names)

fig, axs = plt.subplots(
    3,
    num_experiments,
    figsize=(2 + num_experiments * 6, 6),
    sharex="col",
    sharey="row",
    squeeze=False,  # keep axs 2-D even when there is a single experiment
    gridspec_kw={"height_ratios": [3, 1, 1]},
)

for i, experiment_name in enumerate(experiment_names):
    results_path = Path(logs_base_path, "results", f"{experiment_name}.json")
    df = pd.read_json(results_path)

    # Some logs omit the step column; reconstruct it from the row index.
    if "step" not in df.columns:
        df["step"] = df.index * STEP_INTERVAL

    ax_acc, ax_align, ax_dots = axs[:, i]
    is_first_col = i == 0
    is_last_col = i == num_experiments - 1

    sns.lineplot(data=df, x="step", y="training_acc", label="Transformer train", ax=ax_acc)
    sns.lineplot(data=df, x="step", y="test_acc", label="Transformer test", ax=ax_acc)
    sns.lineplot(data=df, x="step", y="svm_accuracy", label="SVM", ax=ax_acc)
    sns.lineplot(data=df, x="step", y="gp_accuracy", label="GP", ax=ax_acc)
    # sharex="col" hides inner tick labels but not the axis labels seaborn sets,
    # so blank the x-label on the upper rows and keep "step" only at the bottom.
    ax_acc.set(
        xlabel="",
        ylabel="Accuracy" if is_first_col else "",
        title=experiment_name,
    )

    sns.lineplot(data=df, x="step", y="kernel_alignment", ax=ax_align)
    ax_align.set(xlabel="", ylabel="Kernel alignment" if is_first_col else "")

    sns.lineplot(data=df, x="step", y="dots", ax=ax_dots)
    ax_dots.set(ylabel="DOTS" if is_first_col else "")

    # Every accuracy line is labelled, so seaborn attaches a legend to each top
    # panel. Keep one shared legend outside the last column and drop the rest.
    if is_last_col:
        handles, labels = ax_acc.get_legend_handles_labels()
        ax_acc.legend(
            handles,
            labels,
            loc="center left",
            bbox_to_anchor=(1.15, 0.5),
        )
    else:
        legend = ax_acc.get_legend()
        if legend is not None:
            legend.remove()

fig.tight_layout()
plt.show()