In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import scienceplots

# Apply the "science" and "grid" style sheets to all figures drawn after this call.
# NOTE(review): `scienceplots` is imported above but never referenced by name —
# presumably the import itself is what makes these style names available, so
# it should not be removed as "unused"; confirm against the package docs.
plt.style.use(["science", "grid"])

In [None]:
# Directory holding the per-run JSON result logs. The path is relative to the
# directory the notebook is run from; adjust it when running from elsewhere.
logs_base_path = "logs/results"

# One log file per run number; the number is the trailing `-<n>` of the
# `mnist-fixed-norm-gf-<n>.json` file name.
results = [
    Path(logs_base_path) / f"mnist-fixed-norm-gf-{exp}.json"
    for exp in (43, 44, 45, 46, 47, 48, 49, 50, 52, 53, 54, 55, 56, 58, 59, 60)
]

In [None]:
# --- Collect the best logged row of every run --------------------------------
# `dfs` ends up holding one Series per results file: the logged row with the
# highest test accuracy, extended with the DOTS value logged at step 0 and the
# path of the file the row came from.
dfs = []
for results_file in results:
    run_df = pd.read_json(results_file)
    # `idxmax` returns an index *label*, so the row must be selected with
    # `.loc`; `.iloc` is positional and would only pick the right row while
    # the index happens to be 0..n-1.
    row = run_df.loc[run_df["test_acc"].idxmax()].copy()
    # DOTS at step 0, kept so it can be compared with the value at the best row.
    row["dots_init"] = run_df.query("step == 0")["dots"].iloc[0]
    row["source"] = results_file
    dfs.append(row)

# One row per run. Concatenating mixed-type Series yields object-dtype columns,
# so let pandas recover numeric dtypes before doing arithmetic and plotting.
df = pd.concat(dfs, axis=1).T.infer_objects()

# --- Figure: three stacked panels sharing the weight-norm axis ----------------
fig, axs = plt.subplots(
    3,
    1,
    figsize=(5, 4),
    sharex=True,
    squeeze=False,
    height_ratios=[2, 1, 1],
)

# `squeeze=False` returns a 2-D array of axes; flatten it to unpack the panels.
ax1, ax2, ax3 = axs.flatten()

# Marker settings shared by every line in the figure.
point_style = dict(marker="o", markersize=4)

# Top panel: error (1 - accuracy) for each accuracy column against weight norm.
for key, name in [
    ("training_acc", "MLP (train)"),
    ("test_acc", "MLP (test)"),
    # ("svm_train_accuracy", "SVM train error"),
    ("svm_accuracy", "SVM"),
    ("gp_accuracy", "Kernel regression"),
]:
    sns.lineplot(
        x=df["weight_norm"],
        y=1 - df[key],
        label=name,
        ax=ax1,
        **point_style,
    )
ax1.set(xscale="log", ylabel="Error", xlabel="Weight norm")

# Middle panel: kernel alignment against weight norm.
sns.lineplot(x=df["weight_norm"], y=df["kernel_alignment"], ax=ax2, **point_style)
ax2.set(xscale="log", ylabel="Kernel alignment", xlabel="Weight norm")

# Bottom panel: DOTS at step 0 ("Initial") versus DOTS at the best row ("Final").
sns.lineplot(
    x=df["weight_norm"],
    y=df["dots_init"],
    ax=ax3,
    label="Initial",
    **point_style,
)
sns.lineplot(
    x=df["weight_norm"],
    y=df["dots"],
    ax=ax3,
    label="Final",
    **point_style,
)
ax3.set(xscale="log", ylabel="DOTS", xlabel="Weight norm")

fig.tight_layout()

# Uncomment to export the figure:
# plt.savefig("logs/plots/mnist-fixed-norm-gf.pdf", bbox_inches="tight")

In [None]:
# Show the runs whose weight norm exceeds 10, ordered by increasing weight
# norm, together with the log file each one came from.
(
    df.loc[df["weight_norm"] > 10, ["weight_norm", "source"]]
    .sort_values("weight_norm")
)