In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
# Directory containing the JSON log files read by the cells below.
logs_base_path = "logs/results"

# Path.glob yields entries in arbitrary, filesystem-dependent order; sort so
# the run order (and therefore figure column order) is reproducible.
results = sorted(Path(logs_base_path).glob("mnist-fixed-norm-*.json"))

In [None]:
# Collapse each run's log into the row with the largest "step", giving one
# summary row per run file.
if not results:
    # Fail with a clear message; pd.concat([]) would otherwise raise an
    # opaque "No objects to concatenate" error.
    raise FileNotFoundError(
        f"No mnist-fixed-norm-*.json files found in {logs_base_path!r}"
    )

dfs = []
for results_file in results:
    with open(results_file, "r") as f:
        run_df = pd.read_json(f)
    # idxmax() returns an index *label*, so look it up with .loc; .iloc is
    # positional and only agrees with it for a default RangeIndex.
    dfs.append(run_df.loc[run_df["step"].idxmax()])

# A row Series is upcast to `object` dtype when its run mixes column types,
# and transposing keeps that dtype; infer_objects() restores numeric columns.
df = pd.concat(dfs, axis=1).T.infer_objects()


def _plot_against_weight_norm(ax, summary, curves, **ax_settings):
    """Plot each ``(y_values, label)`` pair in ``curves`` against
    ``summary["weight_norm"]`` on ``ax`` with markers and a log-scaled x-axis.

    A ``label`` of ``None`` draws the curve without a legend entry. Extra
    keyword arguments are forwarded to ``ax.set`` (e.g. ``ylabel``, ``ylim``).
    """
    for y_values, label in curves:
        sns.lineplot(
            x=summary["weight_norm"], y=y_values, label=label, ax=ax, marker="o"
        )
    ax.set(xscale="log", xlabel="Weight norm", **ax_settings)


fig, axs = plt.subplots(4, 1, figsize=(6, 8), sharex="col", squeeze=False)

# The logs store accuracies; plot them as error rates (1 - accuracy).
accuracy_metrics = [
    ("training_acc", "Training error"),
    ("test_acc", "Test error"),
    ("svm_train_accuracy", "SVM train error"),
    ("svm_accuracy", "SVM test error"),
    ("gp_accuracy", "GP error"),
]
_plot_against_weight_norm(
    axs[0, 0],
    df,
    [(1 - df[key], name) for key, name in accuracy_metrics],
    ylabel="Error",
)

loss_metrics = [
    ("train_loss", "Training loss"),
    ("test_loss", "Test loss"),
]
_plot_against_weight_norm(
    axs[1, 0],
    df,
    [(df[key], name) for key, name in loss_metrics],
    ylabel="Loss",
    ylim=(-1e-3, 0.1),
)

_plot_against_weight_norm(
    axs[2, 0], df, [(df["kernel_alignment"], None)], ylabel="Kernel alignment"
)
_plot_against_weight_norm(axs[3, 0], df, [(df["dots"], None)], ylabel="Dots")

fig.tight_layout()
plt.show()

In [None]:
# NOTE(review): `experiment_names` / `num_experiments` were never defined in
# this notebook, so this cell raised NameError on a fresh kernel. Default to
# the runs globbed above; replace with an explicit list if a different set of
# experiments (e.g. ones logging `relative_weight_norm`) was intended.
experiment_names = [results_file.stem for results_file in results]
num_experiments = len(experiment_names)
if num_experiments == 0:
    # plt.subplots rejects a zero-column grid; fail with a clearer message.
    raise FileNotFoundError(f"No experiment logs found in {logs_base_path!r}")

fig, axs = plt.subplots(
    2,
    num_experiments,
    figsize=(2 + num_experiments * 6, 4),
    sharex="col",
    sharey="row",
    squeeze=False,
)

for i, experiment_name in enumerate(experiment_names):
    # logs_base_path already points at the results directory, so the log sits
    # directly inside it (no extra "results" path component).
    results_path = Path(logs_base_path, f"{experiment_name}.json")

    with open(results_path, "r") as json_file:
        run_df = pd.read_json(json_file)

    # NOTE(review): assumes rows were logged every 5 steps when the log has no
    # explicit "step" column — confirm against the training logger.
    if "step" not in run_df:
        run_df["step"] = run_df.index * 5

    train_ax, test_ax = axs[0, i], axs[1, i]

    sns.lineplot(data=run_df, x="relative_weight_norm", y="train_loss", ax=train_ax)
    train_ax.set(xlabel="Weight norm", ylabel="Train loss", title=experiment_name)

    sns.lineplot(data=run_df, x="relative_weight_norm", y="test_loss", ax=test_ax)
    test_ax.set(xlabel="Weight norm", ylabel="Test loss")

fig.tight_layout()
plt.show()