In [None]:
import math
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Imported only for its side effect: the SciencePlots package registers the
# "science" and "grid" matplotlib styles when it is imported. The name is never
# referenced again, so linters may flag it as unused -- do not remove it.
import scienceplots

# Apply the SciencePlots styles to every figure created in this notebook.
plt.style.use(["science", "grid"])

In [None]:
# Root of the logs tree; result files are read from "<logs_base_path>/results/".
# Alternative locations (uncomment exactly one assignment):
# logs_base_path = "../../logs/"
# logs_base_path = "/home/dc755/idiots/logs/"
logs_base_path = "logs/"

# Experiments shown in the overview figure, one column each.
# Uncomment an entry to include that run.
experiment_names = [
    # "mnist-gd-grokking-2",
    # "mnist-grokking-slower-2",
    # "mnist-grokking-slower",
    # "mnist-adamw",
    # "mnist-adamw-64",
    # "mnist-adamw-256",
    # "mnist-gf",
    # "mnist-gd",
    # "mnist-sgd",
    # "mnist-sgd-8",
    # "mnist-sgd-16",
    # "addition-gf",
    # "addition-adamw-mlp",
    # "division-gf-mlp",
    # "division-adamw-mlp",
    # "division-adamw-transformer",
    # "division-adamw-mlp-1",
]

# Add the "mnist-fixed-norm-gf-<id>" runs for ids 43..60.
# Ids 51 and 57 are deliberately left out of the sweep.
experiment_names.extend(
    f"mnist-fixed-norm-gf-{run_id}"
    for run_id in range(43, 61)
    if run_id not in (51, 57)
)

In [None]:
# Overview figure: one column per entry of `experiment_names`, six rows:
#   0 accuracy curves, 1 losses, 2 kernel alignment, 3 DOTS, 4 weight norm,
#   5 eigenvalue spectrum (one line per step, log-scaled y-axis).
num_experiments = len(experiment_names)

# Stride used to synthesise a "step" column for result files that lack one.
# NOTE(review): presumably metrics were logged every 5 steps -- confirm.
FALLBACK_STEP_STRIDE = 5

# (column in the results file, legend label) for the accuracy panel.
ACCURACY_CURVES = [
    ("training_acc", "Transformer train"),
    ("test_acc", "Transformer test"),
    ("svm_accuracy", "SVM"),
    ("svm_train_accuracy", "SVM train"),
    ("gp_accuracy", "GP"),
]

fig, axs = plt.subplots(
    6,
    num_experiments,
    figsize=(2 + num_experiments * 4, 10),
    # sharex="col",
    sharey="row",
    squeeze=False,  # keep `axs` 2-D even when there is a single experiment
    gridspec_kw={"height_ratios": [2, 1, 1, 1, 1, 2]},
)

for i, experiment_name in enumerate(experiment_names):
    results_path = Path(logs_base_path, "results", f"{experiment_name}.json")
    df = pd.read_json(results_path)

    if "step" not in df.columns:
        df["step"] = df.index * FALLBACK_STEP_STRIDE

    ax1, ax2, ax3, ax4, ax5, ax6 = axs[:, i]
    # Only the left-most column carries y-axis labels.
    is_first_column = i == 0

    for column, label in ACCURACY_CURVES:
        sns.lineplot(data=df, x="step", y=column, label=label, ax=ax1)
    ax1.set(ylabel="Accuracy" if is_first_column else "", title=experiment_name)

    sns.lineplot(data=df, x="step", y="train_loss", label="Train loss", ax=ax2)
    sns.lineplot(data=df, x="step", y="test_loss", label="Test loss", ax=ax2)
    ax2.set(ylabel="Loss" if is_first_column else "")

    sns.lineplot(data=df, x="step", y="kernel_alignment", ax=ax3)
    ax3.set(ylabel="Kernel alignment" if is_first_column else "")

    sns.lineplot(data=df, x="step", y="dots", label="Threshold 1 (default)", ax=ax4)
    # sns.lineplot(data=df, x="step", y="dots_2", label="Threshold 2", ax=ax4)
    # sns.lineplot(data=df, x="step", y="dots_3", label="Perplexity", ax=ax4)
    ax4.set(ylabel="DOTS" if is_first_column else "")

    sns.lineplot(data=df, x="step", y="weight_norm", ax=ax5)
    ax5.set(ylabel="Weight norm" if is_first_column else "")

    # Flatten the per-step eigenvalue lists into long form:
    # one row per (step, rank-within-list, eigenvalue).
    eigen_df = pd.DataFrame(
        [
            {"step": step, "eigen_rank": rank, "eigenvalue": value}
            for step, spectrum in zip(df["step"], df["eigenvalues"])
            for rank, value in enumerate(spectrum)
        ]
    )
    sns.lineplot(data=eigen_df, x="eigen_rank", y="eigenvalue", hue="step", ax=ax6)
    ax6.set(ylabel="Eigenvalues" if is_first_column else "", yscale="log")

    # Show legends only on the last column, placed just outside the axes.
    is_last_column = i == num_experiments - 1
    for ax in axs[:, i]:
        if ax.get_legend() is None:
            continue
        if is_last_column:
            # move_legend re-creates the legend from the existing one, so the
            # entries and title seaborn set up (e.g. for hue="step") survive;
            # a bare ax.legend(...) would rebuild it without the title.
            sns.move_legend(ax, "center left", bbox_to_anchor=(1, 0.5))
        else:
            ax.get_legend().remove()

fig.tight_layout()

# Make sure the output directory exists so the save cannot fail after all the
# plotting work has been done.
plot_path = Path("logs/plots/plot.pdf")
plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_path, bbox_inches="tight")

In [None]:
# plot for experiment 1
# Side-by-side comparison of the three "division-adamw-*" runs. Each column is
# one run; rows are accuracy curves, kernel alignment and DOTS.

# Read from the notebook-wide logs root so this cell follows `logs_base_path`
# like the other cells instead of a separately hard-coded "logs" directory.
result_base_path = Path(logs_base_path, "results")
experiments = [
    {
        "file": result_base_path / "division-adamw-mlp-1.json",
        "name": "2-layer MLP",
    },
    {
        "file": result_base_path / "division-adamw-mlp.json",
        "name": "3-layer MLP",
    },
    {
        "file": result_base_path / "division-adamw-transformer.json",
        "name": "Transformer",
    },
]

fig, axs = plt.subplots(
    3,
    len(experiments),
    figsize=(10, 4),
    sharex="col",
    squeeze=False,  # keep `axs` 2-D regardless of the number of experiments
    height_ratios=[2, 1, 1],
)

for i, experiment in enumerate(experiments):
    df = pd.read_json(experiment["file"])
    acc_ax, alignment_ax, dots_ax = axs[:, i]

    for key, label in [
        ("training_acc", "NN train"),
        ("test_acc", "NN test"),
        ("svm_accuracy", "SVM"),
        ("gp_accuracy", "Kernel regression"),
    ]:
        sns.lineplot(data=df, x="step", y=key, label=label, ax=acc_ax)
    acc_ax.set(title=experiment["name"], ylabel="Accuracy", xlabel="Step")
    # Keep the accuracy legend on the last column only.
    if i < len(experiments) - 1:
        acc_ax.get_legend().remove()

    sns.lineplot(data=df, x="step", y="kernel_alignment", ax=alignment_ax)
    alignment_ax.set(ylabel="Kernel alignment", xlabel="Step")

    sns.lineplot(data=df, x="step", y="dots", ax=dots_ax)
    dots_ax.set(ylabel="DOTS", xlabel="Step")

# Give the accuracy and kernel-alignment rows common y-limits across columns
# (the union of the per-panel limits); the DOTS row keeps per-panel limits.
for row_axes in axs[:2]:
    y_min = min(ax.get_ylim()[0] for ax in row_axes)
    y_max = max(ax.get_ylim()[1] for ax in row_axes)
    for ax in row_axes:
        ax.set_ylim(y_min, y_max)

# Only the left-most column keeps its y-axis labels.
for ax in axs[:, 1:].flat:
    ax.set(ylabel="")

fig.tight_layout()

# fig.savefig("logs/plots/division-comparison.pdf", bbox_inches="tight")

In [None]:
# plot for mnist
# Single-column figure for the "mnist-adamw" run: accuracy curves, kernel
# alignment and DOTS. The layout code handles any number of experiments, so
# more entries can be appended to `experiments`.

# Read from the notebook-wide logs root so this cell follows `logs_base_path`
# like the other cells instead of a separately hard-coded "logs" directory.
result_base_path = Path(logs_base_path, "results")
experiments = [
    {
        "file": result_base_path / "mnist-adamw.json",
        "name": "MNIST",
    },
]

fig, axs = plt.subplots(
    3,
    len(experiments),
    figsize=(5, 4),
    sharex="col",
    squeeze=False,  # keep `axs` 2-D even with a single experiment
    height_ratios=[2, 1, 1],
)

for i, experiment in enumerate(experiments):
    df = pd.read_json(experiment["file"])
    acc_ax, alignment_ax, dots_ax = axs[:, i]

    for key, label in [
        ("training_acc", "MLP train"),
        ("test_acc", "MLP test"),
        ("svm_accuracy", "SVM"),
        ("gp_accuracy", "Kernel regression"),
    ]:
        sns.lineplot(data=df, x="step", y=key, label=label, ax=acc_ax)
    acc_ax.set(ylabel="Accuracy", xlabel="Step")
    # Keep the accuracy legend on the last column only.
    if i < len(experiments) - 1:
        acc_ax.get_legend().remove()

    sns.lineplot(data=df, x="step", y="kernel_alignment", ax=alignment_ax)
    alignment_ax.set(ylabel="Kernel alignment", xlabel="Step")

    sns.lineplot(data=df, x="step", y="dots", ax=dots_ax)
    dots_ax.set(ylabel="DOTS", xlabel="Step")

# Give the accuracy row common y-limits across columns (the union of the
# per-panel limits); the other rows keep per-panel limits.
for row_axes in axs[:1]:
    y_min = min(ax.get_ylim()[0] for ax in row_axes)
    y_max = max(ax.get_ylim()[1] for ax in row_axes)
    for ax in row_axes:
        ax.set_ylim(y_min, y_max)

# Only the left-most column keeps its y-axis labels.
for ax in axs[:, 1:].flat:
    ax.set(ylabel="")

fig.tight_layout()

# fig.savefig("logs/plots/mnist.pdf", bbox_inches="tight")

In [None]:
# Ablate DOTS
#
# Left column: the three columns "dots", "dots_2" and "dots_3" of one run,
# stacked and plotted against the training step.
# Right column: the run's eigenvalues by rank, one line per step (hue).

# Other runs this cell can be pointed at:
# exp_name = "division-adamw-transformer"
# exp_name = "division-adamw-mlp"
# exp_name = "division-adamw-mlp-1"
# exp_name = "mnist-adamw"
exp_name = "division-gf-mlp"
checkpoint_dir = Path(logs_base_path, "results", f"{exp_name}.json")

with checkpoint_dir.open("r") as json_file:
    df = pd.read_json(json_file)

# 3 x 2 layout: three small panels on the left, one tall panel on the right
# that spans all three rows.
fig = plt.figure(figsize=(8, 3.3))
grid = plt.GridSpec(3, 2, wspace=0.25, hspace=0.1, width_ratios=[3, 5])

# DOTS variants, top to bottom.
ax1, ax2, ax3 = (fig.add_subplot(grid[row, 0]) for row in range(3))
for panel, column in zip((ax1, ax2, ax3), ("dots", "dots_2", "dots_3")):
    sns.lineplot(data=df, x="step", y=column, ax=panel)
# Only the bottom panel shows x tick labels and an x-axis label.
ax1.set(ylabel="DOTS", xticklabels=[], xlabel=None)
ax2.set(ylabel="DOTS (2)", xticklabels=[], xlabel=None)
ax3.set(ylabel="Eigen perplexity", xlabel="Step")

# Eigenvalue spectrum: flatten the per-step lists into long form, one row per
# (step, rank-within-list, eigenvalue). The capitalised "Step" key is what
# seaborn uses as the hue legend title.
ax4 = fig.add_subplot(grid[:, 1])
eigen_df = pd.DataFrame(
    [
        {"Step": step, "eigen_rank": rank, "eigenvalue": value}
        for step, spectrum in zip(df["step"], df["eigenvalues"])
        for rank, value in enumerate(spectrum)
    ]
)
sns.lineplot(data=eigen_df, x="eigen_rank", y="eigenvalue", hue="Step", ax=ax4)
ax4.set(yscale="log", ylabel="Values", xlabel="$i^{th}$ eigenvalue")

# fig.savefig("logs/plots/ablation-dots.pdf", bbox_inches="tight")